Setup

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.2     ✔ readr     2.1.4
## ✔ forcats   1.0.0     ✔ stringr   1.5.0
## ✔ ggplot2   3.4.3     ✔ tibble    3.2.1
## ✔ lubridate 1.9.2     ✔ tidyr     1.3.0
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(here)
## here() starts at /Users/jaeyoungson/Documents/GitHub/network-navigation-replay
library(glmmTMB)
## Warning in checkMatrixPackageVersion(): Package version inconsistency detected.
## TMB was built with Matrix version 1.6.0
## Current Matrix version is 1.6.1
## Please re-install 'TMB' from source using install.packages('TMB', type = 'source') or ask CRAN for a binary version of 'TMB' matching CRAN's 'Matrix' package
## Warning in checkDepPackageVersion(dep_pkg = "TMB"): Package version inconsistency detected.
## glmmTMB was built with TMB version 1.9.4
## Current TMB version is 1.9.6
## Please re-install glmmTMB from source or restore original 'TMB' package (see '?reinstalling' for more information)
library(broom.mixed)

library(tidygraph)
## 
## Attaching package: 'tidygraph'
## 
## The following object is masked from 'package:stats':
## 
##     filter
library(ggraph)

library(kableExtra)
## 
## Attaching package: 'kableExtra'
## 
## The following object is masked from 'package:dplyr':
## 
##     group_rows
library(tictoc)
source(here("code", "utils", "ggplot_themes.R"))
source(here("code", "utils", "modeling_utils.R"))
source(here("code", "utils", "representation_utils.R"))
source(here("code", "utils", "unicode_greek.R"))

# Population-level predictions from a fitted glmmTMB model.
#
# Returns `make_predictions_for` with the output of
# predict(..., se.fit = TRUE) bound on the right (columns `fit` and `se.fit`),
# on the response scale. `re.form = NA` excludes the random effects, and
# `allow.new.levels = TRUE` tolerates grouping values unseen during fitting
# (e.g. NA placeholders).
predict_glmmTMB <- function(make_predictions_for, model_object) {
  predictions <- predict(
    object = model_object,
    newdata = make_predictions_for,
    re.form = NA,
    allow.new.levels = TRUE,
    se.fit = TRUE,
    type = "response"
  )
  bind_cols(make_predictions_for, predictions)
}

# Append a `sig` column of significance stars derived from `p.value`:
# "***" for p < .001, "**" for p < .01, "*" for p < .05, "." for p < .1 and
# "" otherwise. Missing p-values (e.g. the random-effect rows of a tidied
# mixed model) also get "".
check_significance <- function(tidy_df) {
  p <- tidy_df[["p.value"]]
  stars <- rep("", length(p))
  # Go from the loosest to the strictest cut-off so each stricter label
  # overwrites the looser one it also satisfies.
  cutoffs <- c("." = 0.1, "*" = 0.05, "**" = 0.01, "***" = 0.001)
  for (label in names(cutoffs)) {
    stars[!is.na(p) & p < cutoffs[[label]]] <- label
  }
  tidy_df$sig <- stars
  tidy_df
}
# Figures are written to disk only when the document is knitted to HTML;
# every ggsave() call below is guarded by this flag.
knitting <- knitr::is_html_output()

# Make sure the output folder exists before anything is saved into it.
if (knitting && !dir.exists(here("figures"))) {
  dir.create(here("figures"))
}

Network visualizations

# Learned and re-evaluated networks, read from adjacency-list CSVs.
# Only rows with `edge == 1` and `from < to` become edges, so each pair
# contributes at most one edge (presumably the CSVs list both directions of
# every pair), and the graphs are built as undirected tidygraph objects.
adjlist <- read_csv(
  here("data", "clean-data", "adjlist_learned.csv"),
  show_col_types = FALSE
)

g <- tbl_graph(
  edges = adjlist %>%
    filter(from < to, edge == 1) %>%
    select(-edge),
  directed = FALSE
)

adjlist_reevaluated <- read_csv(
  here("data", "clean-data", "adjlist_reevaluated.csv"),
  show_col_types = FALSE
)

g_reevaluated <- tbl_graph(
  edges = adjlist_reevaluated %>%
    filter(from < to, edge == 1) %>%
    select(-edge),
  directed = FALSE
)
# Draw a network with the "stress" layout and straight edges, labelling each
# node by its row index in the graph's node table.
draw_network <- function(graph) {
  labelled_graph <- mutate(graph, name = row_number())
  ggraph(labelled_graph, layout = "stress") +
    theme_network() +
    geom_edge_link() +
    geom_node_label(aes(label = name))
}

plot_network_learned <- draw_network(g)
plot_network_reevaluated <- draw_network(g_reevaluated)

print(plot_network_learned)
## Warning: Using the `size` aesthetic in this geom was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` in the `default_aes` field and elsewhere instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.

print(plot_network_reevaluated)

# Save both network diagrams (4 x 2 in PDFs) only when knitting.
if (knitting) {
  network_figures <- list(
    network_learned = plot_network_learned,
    network_reevaluated = plot_network_reevaluated
  )
  for (figure_name in names(network_figures)) {
    ggsave(
      here("figures", str_c(figure_name, ".pdf")),
      plot = network_figures[[figure_name]],
      width = 4, height = 2,
      units = "in", dpi = 300
    )
  }
}

Can humans solve social navigation problems?

In the paper, we start by examining social navigation behavior in a one-day session (Study 1) or in the first session of a two-day study (Studies 2–3). The procedure is identical across all three studies; for this part of the dataset, Studies 2–3 are exact replications of Study 1.

# Load one study's non-tie navigation trials in a common format.
#
# Keeps trials without two correct options whose options' shortest path equals
# the start-to-end shortest path. `label_sessions` is a study-specific function
# that takes the filtered trials and returns them with `measurement_id`
# rewritten to session labels ("D1", "D1b", "D2", "D2b"). The result has the
# same columns for every study so the studies can be stacked with bind_rows().
read_nav_study <- function(file_name, study_label, label_sessions) {
  kept_trials <- read_csv(
    here("data", "clean-data", file_name),
    show_col_types = FALSE
  ) %>%
    filter(
      two_correct_options == FALSE,
      shortest_path_given_opts == shortest_path_given_start_end
    )

  kept_trials %>%
    label_sessions() %>%
    mutate(
      study = study_label,
      shortest_path = factor(shortest_path_given_opts)
    ) %>%
    select(
      study, sub_id, measurement_id, shortest_path,
      startpoint_id, endpoint_id,
      opt1_id, opt2_id,
      correct_choice, sub_choice,
      correct, rt
    )
}

# Study 1: sessions are numbered, so prefix them with "D".
nav_study1 <- read_nav_study(
  "study1_message_passing.csv", "Study 1",
  function(trials) {
    mutate(trials, measurement_id = str_c("D", measurement_id))
  }
)

# Study 2: the re-evaluated network is labelled D2b.
nav_study2 <- read_nav_study(
  "study2_message_passing.csv", "Study 2",
  function(trials) {
    mutate(
      trials,
      measurement_id = case_when(
        network == "learned" ~ str_c("D", measurement_id),
        network == "reevaluated" ~ "D2b"
      )
    )
  }
)

# Study 3: three learned-network measurements (D1, D1b, D2) plus the
# re-evaluated network (D2b).
nav_study3 <- read_nav_study(
  "study3_message_passing.csv", "Study 3",
  function(trials) {
    mutate(
      trials,
      measurement_id = case_when(
        network == "reevaluated" ~ "D2b",
        measurement_id == 1 ~ "D1",
        measurement_id == 2 ~ "D1b",
        measurement_id == 3 ~ "D2"
      )
    )
  }
)

We’ll start off with some descriptive statistics of human behavior. To maximize statistical power, we will pool across studies whenever possible.

# Descriptive day-1 accuracy at each shortest-path distance, pooled over all
# three studies, shown as one column per distance.
bind_rows(nav_study1, nav_study2, nav_study3) %>%
  filter(measurement_id == "D1") %>%
  summarise(
    accuracy = mean(correct),
    .by = c(measurement_id, shortest_path)
  ) %>%
  arrange(measurement_id, shortest_path) %>%
  pivot_wider(
    names_from = shortest_path,
    names_prefix = "dist-",
    values_from = accuracy
  ) %>%
  kbl(
    digits = 2,
    caption = str_c(
      "<center>", "Descriptive: Navigation accuracy", "</center>"
    )
  ) %>%
  kable_styling(bootstrap_options = "responsive")
Descriptive: Navigation accuracy
measurement_id dist-2 dist-3 dist-4
D1 0.8 0.7 0.63

And now the inferential statistical tests. We want to know whether navigation accuracy differs from chance at each distance, so we estimate the same statistical model several times, changing the reference category each time. In each fit, the intercept is the log-odds of a correct choice at the reference distance, and chance (50%) corresponds to log-odds of 0, so the intercept's test is the test against chance. Note that this only reparameterizes the model, such that the same variance is accounted for by different parameters; it does not change the total amount of variance accounted for.

# Day-1 trials pooled over all three studies. Subject IDs are prefixed with
# the study label so subjects from different studies never share an ID.
nav_day1 <- bind_rows(nav_study1, nav_study2, nav_study3) %>%
  filter(measurement_id == "D1") %>%
  mutate(sub_id = str_c(study, " s", sub_id))

# Mixed-effects logistic regression of accuracy on shortest-path distance with
# by-subject random intercepts and distance effects and a by-study random
# intercept. `reference_distance` is moved to the first factor level, so the
# intercept is the log-odds of a correct choice at that distance.
fit_nav_day1 <- function(trials, reference_distance) {
  releveled_trials <- mutate(
    trials,
    shortest_path = fct_relevel(shortest_path, reference_distance)
  )
  glmmTMB(
    correct ~ shortest_path + (1 + shortest_path | sub_id) + (1 | study),
    family = binomial,
    data = releveled_trials
  )
}

stats_nav_day1_dist2 <- fit_nav_day1(nav_day1, "2")
stats_nav_day1_dist3 <- fit_nav_day1(nav_day1, "3")
stats_nav_day1_dist4 <- fit_nav_day1(nav_day1, "4")

# Fixed effects and random-effect (co)variances from each reparameterisation
# of the day-1 model, stacked with a `ref_cat` label and significance stars.
tidy_nav_day1 <- map_dfr(
  .x = list(
    "dist-2" = stats_nav_day1_dist2,
    "dist-3" = stats_nav_day1_dist3,
    "dist-4" = stats_nav_day1_dist4
  ),
  .f = ~tidy(.x, conf.int = TRUE),
  .id = "ref_cat"
) %>%
  check_significance()

# Row-group sizes come from the data (runs of `ref_cat`, in stacking order)
# instead of hard-coded row ranges, so the table grouping stays correct if the
# number of terms per model changes (e.g. a different random-effects structure).
ref_cat_runs <- rle(tidy_nav_day1$ref_cat)
ref_cat_index <- setNames(
  ref_cat_runs$lengths,
  str_c("Ref. Cat. ", ref_cat_runs$values)
)

tidy_nav_day1 %>%
  select(-c(ref_cat, effect, component)) %>%
  kbl(
    caption = str_c("<center>", "Navigation accuracy: Day 1", "</center>"),
    digits = 3
  ) %>%
  kable_styling(bootstrap_options = c("responsive")) %>%
  pack_rows(index = ref_cat_index)
Navigation accuracy: Day 1
group term estimate std.error statistic p.value conf.low conf.high sig
Ref. Cat. dist-2
NA (Intercept) 1.682 0.118 14.303 0 1.451 1.912 ***
NA shortest_path3 -0.622 0.066 -9.428 0 -0.751 -0.493 ***
NA shortest_path4 -1.026 0.081 -12.710 0 -1.184 -0.868 ***
sub_id sd__(Intercept) 1.041 NA NA NA 0.897 1.208
sub_id sd__shortest_path3 0.437 NA NA NA 0.320 0.596
sub_id sd__shortest_path4 0.748 NA NA NA 0.624 0.897
sub_id cor__(Intercept).shortest_path3 -0.392 NA NA NA -0.645 -0.008
sub_id cor__(Intercept).shortest_path4 -0.538 NA NA NA -0.884 0.179
sub_id cor__shortest_path3.shortest_path4 0.944 NA NA NA 0.236 0.976
study sd__(Intercept) 0.122 NA NA NA 0.019 0.765
Ref. Cat. dist-3
NA (Intercept) 1.059 0.114 9.335 0 0.837 1.282 ***
NA shortest_path2 0.622 0.066 9.428 0 0.493 0.751 ***
NA shortest_path4 -0.404 0.062 -6.549 0 -0.525 -0.283 ***
sub_id sd__(Intercept) 0.958 NA NA NA 0.807 1.137
sub_id sd__shortest_path2 0.437 NA NA NA 0.320 0.596
sub_id sd__shortest_path4 0.366 NA NA NA 0.243 0.549
sub_id cor__(Intercept).shortest_path2 -0.030 NA NA NA -0.398 0.351
sub_id cor__(Intercept).shortest_path4 -0.352 NA NA NA -0.364 0.248
sub_id cor__shortest_path2.shortest_path4 -0.737 NA NA NA -0.688 0.874
study sd__(Intercept) 0.122 NA NA NA 0.019 0.765
Ref. Cat. dist-4
NA (Intercept) 0.655 0.108 6.049 0 0.443 0.868 ***
NA shortest_path2 1.026 0.081 12.709 0 0.868 1.184 ***
NA shortest_path3 0.404 0.062 6.549 0 0.283 0.525 ***
sub_id sd__(Intercept) 0.897 NA NA NA 0.765 1.053
sub_id sd__shortest_path2 0.748 NA NA NA 0.624 0.897
sub_id sd__shortest_path3 0.366 NA NA NA 0.243 0.549
sub_id cor__(Intercept).shortest_path2 -0.209 NA NA NA -0.432 0.051
sub_id cor__(Intercept).shortest_path3 -0.031 NA NA NA -0.472 0.151
sub_id cor__shortest_path2.shortest_path3 0.919 NA NA NA -0.524 0.986
study sd__(Intercept) 0.122 NA NA NA 0.019 0.765

We’ll plot out the raw data, plus model predictions…

# Population-level predicted day-1 accuracy (with standard errors) at each
# distance. `sub_id` and `study` are NA placeholders; predict_glmmTMB() leaves
# the random effects out of the prediction.
predict_nav_day1 <- predict_glmmTMB(
  make_predictions_for = expand_grid(
    measurement_id = "D1",
    shortest_path = factor(2:4),
    sub_id = NA,
    study = NA
  ),
  model_object = stats_nav_day1_dist2
)

# Per-subject day-1 accuracy (dots) overlaid with the model's population-level
# predicted accuracy (point ranges of +/- 1 SE, joined by a line).
#
# The raw data come from `nav_day1`, the same pooled table the model was
# fitted to. It contains all three studies and study-prefixed subject IDs.
# Stacking the per-study tables with raw `sub_id`s would omit Study 1 and merge
# distinct subjects who share an ID across studies.
plot_nav_day1 <- nav_day1 %>%
  group_by(sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  ggplot(aes(x=shortest_path, y=accuracy)) +
  theme_custom() +
  # Chance level for a two-option choice
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_dotplot(
    binwidth = 0.01,
    binaxis = "y", stackdir = "center",
    position = position_dodge(width = 0.75),
    dotsize = 1, alpha = 0.5, color = NA,
    show.legend = FALSE
  ) +
  geom_pointrange(
    aes(x = shortest_path, y = fit, ymin = fit - se.fit, ymax = fit + se.fit),
    data = predict_nav_day1, inherit.aes = FALSE, show.legend = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1
  ) +
  geom_line(
    aes(x = shortest_path, y = fit, group = measurement_id),
    data = predict_nav_day1, inherit.aes = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", labels = scales::percent, breaks = seq(0, 1, 0.1)
  ) +
  coord_cartesian(ylim = c(0.4, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle("Human social navigation")

print(plot_nav_day1)

# Write the figure only when knitting.
if (knitting) {
  ggsave(
    filename = here("figures", str_c("navigation_day1", ".pdf")),
    plot = plot_nav_day1,
    width = 8/3,
    height = 3,
    units = "in",
    dpi = 300
  )
}

Computational models of social navigation

We’d like to have some mechanistic insights about how people solve navigation problems. To do this, we’ll look at two candidate models of navigation: breadth-first search (BFS) and the Successor Representation (SR).

Simulations

BFS simulation

# Simulated BFS-agent runs on the learned network, restricted to the same trial
# types as the human analyses: no trials with two correct options, and the
# options' shortest path equal to the start-to-end shortest path.
sim_bfs <- read_csv(
  here("data", "bfs-sims", "bfs_sims_learned.csv"),
  show_col_types = FALSE
) %>%
  filter(two_correct_options == FALSE) %>%
  filter(shortest_path_given_opts == shortest_path_given_start_end) %>%
  mutate(shortest_path = factor(shortest_path_given_opts)) %>%
  select(
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    bfs_choice, bfs_correct_choice, bfs_n_visits_total
  )

In our implementation, we model the BFS agent as having some “threshold” for searching through the network. This can be thought of as a “willingness to spend time/effort performing a search” threshold. Once that threshold is exceeded, it becomes increasingly likely that the agent gives up and chooses randomly.

To see which threshold values might be informative, we’ll look at the average number of “searches” an agent must perform to make a (non-random) decision.

# Mean number of BFS visits needed to reach a decision, by distance.
sim_bfs %>%
  summarise(avg_n_searches = mean(bfs_n_visits_total), .by = shortest_path) %>%
  arrange(shortest_path) %>%
  kbl(
    digits = 2,
    caption = str_c("<center>", "Average searches in online-BFS", "</center>")
  ) %>%
  kable_styling(bootstrap_options = "responsive")
Average searches in online-BFS
shortest_path avg_n_searches
2 4.92
3 7.70
4 11.16

And now we’ll plot the model predictions…

# Simulated BFS accuracy and mean number of visits for each unique trial
# (distance, start, end and the two options), sorted by those keys.
bfs_avg_accuracy <- sim_bfs %>%
  summarise(
    bfs_accuracy = mean(bfs_correct_choice),
    bfs_visits = mean(bfs_n_visits_total),
    .by = c(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id)
  ) %>%
  arrange(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id)

# Simulated BFS accuracy as a function of the search threshold.
#
# For each threshold in 2, 4, ..., 12, the probability of finishing the search
# (p_bfs) is softmax() over (threshold, mean visits needed for the item);
# otherwise the agent gives up and guesses (accuracy 1/2). The extra
# "Never gives up" series is the raw simulated BFS accuracy. Each line is the
# mean over items at each shortest-path distance.
plot_sim_bfs <- bfs_avg_accuracy %>%
  expand_grid(search_threshold = seq(2, 12, 2)) %>%
  rowwise() %>%
  mutate(
    # Note: p(BFS) is 1-p(give up)
    p_bfs = softmax(
      option_values = c(search_threshold, bfs_visits),
      option_chosen = 1,
      temperature = 1
    )
  ) %>%
  ungroup() %>%
  # Weigh the model predictions according to their likelihood
  mutate(
    p_give_up = 1 - p_bfs,
    bfs_threshold_accuracy = (p_bfs * bfs_accuracy) + (p_give_up * (1/2))
  ) %>%
  # Format for plotting: left-padding to width 2 (" 2", ..., "12") makes the
  # thresholds sort numerically as strings in the legend
  mutate(
    search_threshold = str_pad(search_threshold, width = 2, side = "left")
  ) %>%
  bind_rows(
    bfs_avg_accuracy %>%
      mutate(
        search_threshold = "Never gives up",
        bfs_threshold_accuracy = bfs_accuracy
      )
  ) %>%
  # Now plot
  ggplot(
    aes(
      x=shortest_path, y=bfs_threshold_accuracy,
      color=search_threshold, group=search_threshold
    )
  ) +
  theme_custom() +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  # One line per threshold: mean accuracy over items at each distance
  stat_summary(geom = "line", fun = mean, linewidth = 1) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  scale_color_viridis_d(
    name = "Search threshold", option = "magma",
    begin = 0.1, end = 0.9, direction = -1
  ) +
  guides(color = guide_legend(byrow = TRUE, nrow = 1)) +
  coord_cartesian(ylim = c(0.5, 1)) +
  theme(
    panel.grid = element_blank(),
    legend.position = "bottom"
  ) +
  ggtitle("Simulated BFS navigation")

print(plot_sim_bfs)

# Write the figure only when knitting.
if (knitting) {
  ggsave(
    filename = here("figures", str_c("simulated_bfs", ".pdf")),
    plot = plot_sim_bfs,
    width = 8/3,
    height = 3,
    units = "in",
    dpi = 300
  )
}

SR navigation simulation

# Navigation trial list for the learned network, taken from Study 1's
# sub_id 1 and ordered by start and end node.
# NOTE(review): assumes every subject saw the same set of trials — confirm.
triallist_nav_learned <- nav_study1 %>%
  filter(sub_id == 1) %>%
  arrange(startpoint_id, endpoint_id) %>%
  select(
    startpoint_id, endpoint_id, opt1_id, opt2_id,
    correct_choice, shortest_path
  )

We’ll create some simulated learning observations.

# Reproducible seed derived from a fixed sentence.
set.seed(sum(utf8ToInt("Jenny and me was like peas and carrots")))

# Simulated learning observations: 5000 "sets", each containing every
# `edge == 1` (from, to) row of the learned adjacency list, in a freshly
# shuffled order per set.
simulated_paired_associates <- adjlist %>%
  filter(edge == 1) %>%
  select(from, to) %>%
  # One copy of the edge list per set; `set` becomes the first column
  expand_grid(set = 1:5000, .) %>%
  group_by(set) %>%
  # prop = 1 draws every row without replacement, i.e. shuffles each set
  slice_sample(prop = 1) %>%
  ungroup()

And now we’ll simulate an “asymptotic” SR and how it performs in the navigation task.

# Build a Successor Representation for each discount factor.
#
# Arguments:
#   simulated_observations: learning observations, passed unchanged to
#     build_rep_sr() as `learning_data`.
#   gammas: discount factors to simulate (default 0.1, 0.2, ..., 0.9).
#   alpha: learning rate passed to build_rep_sr() as `this_alpha`
#     (default 0.1).
#
# Returns the rows produced by build_rep_sr() for every gamma, stacked in the
# order of `gammas`, with a leading `gamma` column.
simulate_sr <- function(simulated_observations,
                        gammas = seq(0.1, 0.9, 0.1),
                        alpha = 0.1) {
  # One build_rep_sr() call per gamma on the same observations. This avoids
  # materialising a full copy of the observations for every gamma.
  map_dfr(
    .x = gammas,
    .f = function(this_gamma) {
      build_rep_sr(
        learning_data = simulated_observations,
        this_alpha = alpha,
        this_gamma = this_gamma
      ) %>%
        mutate(gamma = this_gamma, .before = 1)
    }
  ) %>%
    ungroup()
}

# Attach SR values for both options of each navigation trial.
#
# `simulated_sr` holds one SR entry per row (`from`, `to`, `sr_value`) plus
# columns identifying the simulation it came from (e.g. `gamma`). For each
# option, the value looked up is the entry from that option (`from`) to the
# trial's target (`to` = `endpoint_id`). `navigation_triallist` is assumed not
# to already contain those simulation columns.
#
# Returns one row per trial x simulation with `opt1_sr` and `opt2_sr` added.
join_sr <- function(navigation_triallist, simulated_sr) {
  # Everything except the SR entry itself identifies the simulation; both
  # options must come from the same simulation.
  simulation_keys <- setdiff(names(simulated_sr), c("from", "to", "sr_value"))

  opt1_values <- simulated_sr %>%
    rename(opt1_sr = sr_value, opt1_id = from, endpoint_id = to)
  opt2_values <- simulated_sr %>%
    rename(opt2_sr = sr_value, opt2_id = from, endpoint_id = to)

  navigation_triallist %>%
    # Each trial matches one row per simulation, and the same
    # (endpoint, option 1) pair can occur on several trials, so this join is
    # many-to-many by design.
    left_join(
      opt1_values,
      by = c("endpoint_id", "opt1_id"),
      relationship = "many-to-many"
    ) %>%
    # The simulation keys now exist on the left, so each trial x simulation
    # row picks up its own option-2 value.
    left_join(
      opt2_values,
      by = c("endpoint_id", "opt2_id", simulation_keys)
    )
}

# SR matrices learned from all simulated observation sets, one per gamma.
sim_sr_matrix_asymptotic <- simulate_sr(simulated_paired_associates)

# Asymptotic SR agent's probability of choosing the correct option on each
# trial (one row per trial x gamma). SR values are scaled by 100 and passed
# through softmax() with temperature 1, treated as an inverse temperature.
sim_sr_behavior_asymptotic <- triallist_nav_learned %>%
  join_sr(sim_sr_matrix_asymptotic) %>%
  mutate(
    opt1_sr = opt1_sr * 100,
    opt2_sr = opt2_sr * 100,
    temperature = 1
  ) %>%
  rowwise() %>%
  mutate(
    p_correct = softmax(
      option_values = c(opt1_sr, opt2_sr),
      option_chosen = if_else(correct_choice == opt1_id, 1, 2),
      temperature = temperature,
      use_inverse_temperature = TRUE
    )
  ) %>%
  ungroup()
## Joining with `by = join_by(endpoint_id, opt1_id)`
## Warning in left_join(., simulated_sr %>% rename(opt1_sr = sr_value, opt1_id = from, : Detected an unexpected many-to-many relationship between `x` and `y`.
## ℹ Row 1 of `x` matches multiple rows in `y`.
## ℹ Row 45 of `y` matches multiple rows in `x`.
## ℹ If a many-to-many relationship is expected, set `relationship = "many-to-many"` to silence this
##   warning.
## Joining with `by = join_by(endpoint_id, opt2_id, gamma)`
# Simulated SR accuracy by shortest-path distance: one coloured line per
# discount factor gamma, one facet row per softmax temperature (tau), each
# point being the mean p_correct over trials.
plot_sim_sr_asymptotic <- sim_sr_behavior_asymptotic %>%
  mutate(
    gamma = factor(gamma),
    # Facet labels such as "τ =  1" (padded so multi-digit values align)
    temperature = str_pad(temperature, width = 2, side = "left"),
    temperature = str_c(unicode_greek["tau"], " = ", temperature)
  ) %>%
  group_by(gamma, temperature, shortest_path) %>%
  summarise(p_correct = mean(p_correct), .groups = "drop") %>%
  ggplot(aes(x=shortest_path, y=p_correct, color=gamma)) +
  theme_custom() +
  facet_grid(rows = vars(temperature)) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_line(aes(color = gamma, group = gamma), linewidth = 0.8) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  scale_color_viridis_d(
    name = str_c(unicode_greek["gamma"], " = "), option = "turbo", end = 0.9
  ) +
  guides(color = guide_legend(byrow = TRUE, nrow = 1)) +
  coord_cartesian(ylim = c(0.5, 1)) +
  theme(legend.position = "bottom") +
  ggtitle("Simulated SR navigation")

print(plot_sim_sr_asymptotic)

# Write the figures only when knitting.
if (knitting) {
  # NOTE(review): the default pdf() device may be unable to render the Greek
  # letters in the labels; presumably that is the warning silenced here.
  suppressWarnings(
    ggsave(
      here("figures", str_c("simulated_sr", ".pdf")),
      plot = plot_sim_sr_asymptotic,
      width = 8/3, height = 3,
      units = "in", dpi = 300
    )
  )

  # Cairo-rendered copy (typically handles Unicode glyphs), with the legend
  # wrapped over two rows.
  ggsave(
    here("figures", str_c("simulated_sr_cairo", ".pdf")),
    plot = plot_sim_sr_asymptotic +
      guides(color = guide_legend(byrow = TRUE, nrow = 2)),
    width = 8/3, height = 3,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
}

Model comparison

To do model comparison, we’ll need to pull in the estimated parameters/likelihoods from the model-fitting.

# Fitted parameters and negative log-likelihoods for every subject, session
# and model. Set `clean_params_from_raw` to TRUE to rebuild
# data/param-fits/clean_param_fits.csv from the raw per-fit files; otherwise
# the previously cleaned CSV is read directly.
clean_params_from_raw <- FALSE

if (clean_params_from_raw) {
  params <- here("data", "param-fits") %>%
    fs::dir_ls(regexp = "study[[:digit:]]_D[[:digit:]]b?") %>%
    # NOTE(review): best_optim_run() is defined in a sourced utils file;
    # presumably it keeps the best of several optimiser runs — confirm.
    map_dfr(
      .f = ~read_csv(.x, show_col_types = FALSE) %>%
        best_optim_run("dataframe"),
      .id = "filename"
    ) %>%
    # Study, session, subject and model are all parsed from the file path
    mutate(
      study = str_extract(filename, "study[[:digit:]]"),
      study = str_remove(study, "study"),
      study = str_c("Study ", study),
      measurement_id = str_extract(filename, "_D[[:digit:]]b?"),
      measurement_id = str_remove(measurement_id, "_"),
      sub_id = str_extract(filename, "sub-[[:digit:]]+"),
      sub_id = str_remove(sub_id, "sub-"),
      sub_id = as.integer(sub_id),
      # First matching tag wins; unmatched paths get NA
      model = case_when(
        str_detect(filename, "_hybrid-") ~ "hybrid",
        str_detect(filename, "_sr-") ~ "sr",
        str_detect(filename, "_bfs-") ~ "bfs"
      )
    ) %>%
    select(
      study, sub_id, measurement_id,
      model,
      param_name, param_value = param_value_human_readable,
      neg_loglik = optim_value
    ) %>%
    arrange(study, sub_id, measurement_id, model)
  
  # Cache the cleaned table so later runs can skip this step
  params %>%
    write_csv(here("data", "param-fits", "clean_param_fits.csv"))
} else {
  params <- here("data", "param-fits", "clean_param_fits.csv") %>%
    read_csv(show_col_types = FALSE)
}
source(here("code", "utils", "bayesian_model_selection.R"))

# Bayesian model selection run separately for each study x session on the
# per-subject log-likelihoods of every model (one column per model).
# bayesian_model_selection() is sourced above; its result is unnested so each
# study x session contributes its rows (including `pxp`).
pxp_results <- params %>%
  distinct(study, sub_id, measurement_id, model, neg_loglik) %>%
  mutate(log_lik = -neg_loglik, .keep = "unused") %>%
  pivot_wider(names_from = model, values_from = log_lik) %>%
  select(-sub_id) %>%
  group_by(study, measurement_id) %>%
  nest() %>%
  mutate(test = map(data, bayesian_model_selection)) %>%
  unnest(test) %>%
  ungroup() %>%
  select(-data)

# PXP of each model per session, one panel per study.
ggplot(pxp_results, aes(x = measurement_id, y = pxp, color = model)) +
  theme_custom() +
  facet_wrap(vars(study), scales = "free_x") +
  geom_point(position = position_dodge(width = 0.75))

# PXP table: one row per study x session, one column per model.
pxp_results %>%
  select(study, measurement_id, model, pxp) %>%
  mutate(across(pxp, as.numeric)) %>%
  pivot_wider(names_from = model, names_prefix = "pxp_", values_from = pxp) %>%
  kbl(
    digits = 3,
    caption = str_c("<center>", "PXP results", "</center>")
  ) %>%
  kable_styling(bootstrap_options = "responsive")
PXP results
study measurement_id pxp_bfs pxp_hybrid pxp_sr
Study 1 D1 0 0.000 1.000
Study 2 D1 0 0.000 1.000
Study 2 D2 0 0.000 1.000
Study 3 D1 0 0.000 0.999
Study 3 D1b 0 0.000 1.000
Study 3 D2 0 0.006 0.994

Posterior predictive check

We’ll start by simulating the model’s predictions for each subject, given their estimated parameters.

# Posterior predictive check for the BFS model: for every subject and session
# (D1, D2), the model's predicted probability of a correct choice on each
# trial, averaged by shortest-path distance, next to the empirical accuracy.
#
# Per trial, the prediction mixes two outcomes:
#   - finishing the search, with probability p_complete_bfs (softmax over
#     the fitted search_threshold and the item's mean number of visits), and
#     then being correct as often as the simulated BFS agent (bfs_accuracy);
#   - giving up and guessing (1/2).
# The fitted lapse rate then mixes in a 50/50 guess.
#
# The prediction is P(correct) under the model. It does not depend on which
# option the subject actually chose, so the subject's responses do not leak
# into the predicted accuracy. This matches the SR check, which asks softmax()
# for the probability of the correct option.
ppc_bfs <- bind_rows(nav_study1, nav_study2, nav_study3) %>%
  filter(measurement_id %in% c("D1", "D2")) %>%
  left_join(
    params %>%
      filter(model == "bfs") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, search_threshold, lapse_rate),
    by = c("study", "sub_id", "measurement_id")
  ) %>%
  left_join(
    bfs_avg_accuracy,
    by = c("shortest_path", "startpoint_id", "endpoint_id", "opt1_id", "opt2_id")
  ) %>%
  # What's the probability of *completing* BFS-online all the way through?
  rowwise() %>%
  mutate(
    p_complete_bfs = softmax(
      option_values = c(search_threshold, bfs_visits),
      option_chosen = 1,
      temperature = 1
    )
  ) %>%
  ungroup() %>%
  # Weigh BFS predictions accordingly
  mutate(
    p_give_up = 1 - p_complete_bfs,
    predicted_correct = (p_complete_bfs * bfs_accuracy) + (p_give_up * 1/2),
    ### Add lapse rate
    #   Dividing by 2 is because there are two options to choose from
    #   Therefore, when lapse rate = 1, this becomes chance = 1/2
    predicted_correct = predicted_correct * (1 - lapse_rate) + (lapse_rate / 2)
  ) %>%
  # Average over trials
  group_by(study, sub_id, measurement_id, shortest_path) %>%
  summarise(
    empirical = mean(correct),
    predicted = mean(predicted_correct),
    .groups = "drop"
  )
## Joining with `by = join_by(study, sub_id, measurement_id)`
## Joining with `by = join_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id)`
# Per-subject SR matrices for the posterior predictive check. Each subject's
# fitted `sr_gamma` is used to build an SR (alpha = 0.1, bidirectional = TRUE)
# from simulated observation sets 1-100. The subject's fitted parameters
# (one column per parameter) are carried alongside the SR values.
ppc_sr_matrix <- local({
  # The learning data are the same for every subject, so filter them once.
  ppc_learning_data <- simulated_paired_associates %>%
    filter(set %in% 1:100)

  params %>%
    filter(model == "sr") %>%
    select(
      study, sub_id, measurement_id, name = param_name, value = param_value
    ) %>%
    pivot_wider() %>%
    group_by(study, sub_id, measurement_id) %>%
    nest() %>%
    mutate(
      sim_sr = map(
        .x = data,
        .f = function(subject_params) {
          build_rep_sr(
            learning_data = ppc_learning_data,
            this_alpha = 0.1,
            this_gamma = subject_params$sr_gamma,
            bidirectional = TRUE
          )
        }
      )
    ) %>%
    unnest(sim_sr) %>%
    unnest(data) %>%
    ungroup()
})

# Posterior predictive check for the SR model: on each D1/D2 trial, the
# probability that the subject's fitted SR agent picks the correct option,
# averaged by subject, session and distance, next to the empirical accuracy.
#
# Option values are the subject's SR entries from each option (`from`) to the
# target (`to`), scaled by 100 and passed through softmax() with the fitted
# temperature (as an inverse temperature) and lapse rate.
ppc_sr_behavior <- bind_rows(nav_study1, nav_study2, nav_study3) %>%
  filter(measurement_id %in% c("D1", "D2")) %>%
  # The fitted parameters come in with the option-1 lookup. The option-2
  # lookup adds only its SR value, so no join ever keys on the
  # floating-point parameter columns.
  left_join(
    ppc_sr_matrix %>%
      select(
        study, sub_id, measurement_id,
        opt1_id = from, endpoint_id = to, opt1_sr = sr_value,
        sr_gamma, softmax_temperature, lapse_rate
      ),
    by = c("study", "sub_id", "measurement_id", "endpoint_id", "opt1_id")
  ) %>%
  left_join(
    ppc_sr_matrix %>%
      select(
        study, sub_id, measurement_id,
        opt2_id = from, endpoint_id = to, opt2_sr = sr_value
      ),
    by = c("study", "sub_id", "measurement_id", "endpoint_id", "opt2_id")
  ) %>%
  rowwise() %>%
  mutate(
    predicted_correct = softmax(
      option_values = c(opt1_sr, opt2_sr) * 100,
      option_chosen = if_else(correct_choice == opt1_id, 1, 2),
      temperature = softmax_temperature,
      use_inverse_temperature = TRUE,
      lapse_rate = lapse_rate
    )
  ) %>%
  ungroup() %>%
  # Fix trials when the softmax becomes undefined (NaN).
  # NOTE(review): these trials are filled with the subject's own outcome, so
  # the prediction matches the data exactly on them. Consider using the
  # model's limiting (deterministic) choice instead.
  mutate(
    predicted_correct = case_when(
      is.nan(predicted_correct) & (sub_choice == correct_choice) ~ 1,
      is.nan(predicted_correct) & (sub_choice != correct_choice) ~ 0,
      TRUE ~ predicted_correct
    )
  ) %>%
  # Average over trials
  group_by(study, sub_id, measurement_id, shortest_path) %>%
  summarise(
    empirical = mean(correct),
    predicted = mean(predicted_correct),
    .groups = "drop"
  )
## Joining with `by = join_by(study, sub_id, measurement_id, endpoint_id, opt1_id)`
## Joining with `by = join_by(study, sub_id, measurement_id, endpoint_id, opt2_id, sr_gamma,
## softmax_temperature, lapse_rate)`

Now that we have both PPCs, we can plot them side-by-side.

# Day-1 posterior predictive check figure: per-subject empirical accuracy and
# BFS- and SR-predicted accuracy at each distance (faint points), with the
# mean per agent drawn as a crossbar.
plot_ppc_day1 <- ppc_bfs %>%
  rename(predicted_bfs = predicted) %>%
  # Join on the identifying keys only. `empirical` is the same observed
  # accuracy in both tables (mean of `correct` per subject/session/distance),
  # so it is taken from the BFS table rather than used as a floating-point
  # join key.
  left_join(
    ppc_sr_behavior %>%
      select(
        study, sub_id, measurement_id, shortest_path,
        predicted_sr = predicted
      ),
    by = c("study", "sub_id", "measurement_id", "shortest_path")
  ) %>%
  pivot_longer(
    c(empirical, predicted_bfs, predicted_sr),
    names_to = "agent", values_to = "accuracy"
  ) %>%
  filter(measurement_id == "D1") %>%
  ggplot(aes(x=shortest_path, y=accuracy)) +
  theme_custom() +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_point(
    aes(color = agent),
    alpha = 0.05,
    position = position_jitterdodge(
      jitter.width = 0.2, jitter.height = 0, dodge.width = 0.75, seed = 1
    ),
    show.legend = FALSE
  ) +
  stat_summary(
    aes(color = agent), geom = "crossbar", fun = mean,
    position = position_dodge(0.5)
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", labels = scales::percent, breaks = seq(0, 1, 0.25)
  ) +
  scale_color_manual(
    name = NULL,
    values = c(
      "empirical"="#fd8d3c",
      "predicted_bfs"="#8c2d04",
      "predicted_sr"="#bd0026"
    ),
    labels = c("empirical"="Human", "predicted_bfs"="BFS", "predicted_sr"="SR")
  ) +
  theme(legend.position = "bottom") +
  ggtitle("Posterior predictive check")
## Joining with `by = join_by(study, sub_id, measurement_id, shortest_path,
## empirical)`
print(plot_ppc_day1)

# Write the figure only when knitting.
if (knitting) {
  ggsave(
    filename = here("figures", str_c("ppc_day1", ".pdf")),
    plot = plot_ppc_day1,
    width = 8/2,
    height = 2.5,
    units = "in",
    dpi = 300
  )
}

Held-out trials

On some trials, both Sources were the same shortest-path distance from the Target. We would expect BFS to be largely indifferent between the two Sources, but humans and/or the SR might make different predictions.

# Load one study's tie trials (`two_correct_options == TRUE`) in the same
# format as the main navigation tables, keeping sessions D1 and D2 only.
#
# `label_sessions` takes the filtered trials and returns them with
# `measurement_id` rewritten to session labels ("D1", "D1b", "D2", "D2b").
read_nav_ties <- function(file_name, study_label, label_sessions) {
  tie_trials <- read_csv(
    here("data", "clean-data", file_name),
    show_col_types = FALSE
  ) %>%
    filter(
      two_correct_options == TRUE,
      shortest_path_given_opts == shortest_path_given_start_end
    )

  tie_trials %>%
    label_sessions() %>%
    mutate(
      study = study_label,
      shortest_path = factor(shortest_path_given_opts)
    ) %>%
    select(
      study, sub_id, measurement_id, shortest_path,
      startpoint_id, endpoint_id,
      opt1_id, opt2_id,
      correct_choice, sub_choice,
      correct, rt
    ) %>%
    filter(measurement_id %in% c("D1", "D2"))
}

nav_study1_ties <- read_nav_ties(
  "study1_message_passing.csv", "Study 1",
  function(trials) {
    mutate(trials, measurement_id = str_c("D", measurement_id))
  }
)

nav_study2_ties <- read_nav_ties(
  "study2_message_passing.csv", "Study 2",
  function(trials) {
    mutate(
      trials,
      measurement_id = case_when(
        network == "learned" ~ str_c("D", measurement_id),
        network == "reevaluated" ~ "D2b"
      )
    )
  }
)

nav_study3_ties <- read_nav_ties(
  "study3_message_passing.csv", "Study 3",
  function(trials) {
    mutate(
      trials,
      measurement_id = case_when(
        network == "reevaluated" ~ "D2b",
        measurement_id == 1 ~ "D1",
        measurement_id == 2 ~ "D1b",
        measurement_id == 3 ~ "D2"
      )
    )
  }
)
# Share of day-1 human choices that went to option 1 on each tie item.
tie_item_analysis_humans <- bind_rows(
  nav_study1_ties, nav_study2_ties, nav_study3_ties
) %>%
  filter(measurement_id == "D1") %>%
  mutate(chose_opt1 = sub_choice == opt1_id) %>%
  group_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  summarise(p_human_opt1 = mean(chose_opt1), .groups = "drop")
# Simulated BFS behaviour on tie items: how often the BFS agent chose option 1
# and its mean number of visits, per item.
bfs_avg_accuracy_ties <- read_csv(
  here("data", "bfs-sims", "bfs_sims_learned.csv"),
  show_col_types = FALSE
) %>%
  filter(
    two_correct_options == TRUE,
    shortest_path_given_opts == shortest_path_given_start_end
  ) %>%
  mutate(shortest_path = factor(shortest_path_given_opts)) %>%
  group_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  summarise(
    p_bfs_opt1 = mean(bfs_choice == opt1_id),
    bfs_visits = mean(bfs_n_visits_total),
    .groups = "drop"
  )

# Fitted BFS agents' predicted probability of choosing option 1 on each tie
# item, averaged over all day-1 trials of that item.
#
# Per trial: with probability p_complete_bfs (softmax over the fitted
# search_threshold and the item's mean number of visits), the agent finishes
# the search and picks option 1 as often as the simulated BFS agent did
# (`p_bfs_opt1`). Otherwise it gives up and guesses (1/2). The fitted lapse
# rate then mixes in a 50/50 guess.
#
# This is P(choose option 1) regardless of what the subject chose, so it is
# directly comparable to the human `p_human_opt1` and to the SR model's
# `p_sr_opt1`.
tie_item_analysis_bfs <- bind_rows(
  nav_study1_ties, nav_study2_ties, nav_study3_ties
) %>%
  filter(measurement_id == "D1") %>%
  left_join(
    params %>%
      filter(model == "bfs") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, search_threshold, lapse_rate),
    by = c("study", "sub_id", "measurement_id")
  ) %>%
  left_join(
    bfs_avg_accuracy_ties,
    by = c("shortest_path", "startpoint_id", "endpoint_id", "opt1_id", "opt2_id")
  ) %>%
  # What's the probability of *completing* BFS-online all the way through?
  rowwise() %>%
  mutate(
    p_complete_bfs = softmax(
      option_values = c(search_threshold, bfs_visits),
      option_chosen = 1,
      temperature = 1
    )
  ) %>%
  ungroup() %>%
  # Weigh BFS predictions accordingly
  mutate(
    p_give_up = 1 - p_complete_bfs,
    # Right-hand p_bfs_opt1 is the simulated BFS rate of choosing option 1
    p_bfs_opt1 = (p_complete_bfs * p_bfs_opt1) + (p_give_up * 1/2),
    ### Add lapse rate
    #   Dividing by 2 is because there are two options to choose from
    #   Therefore, when lapse rate = 1, this becomes chance = 1/2
    p_bfs_opt1 = p_bfs_opt1 * (1 - lapse_rate) + (lapse_rate / 2)
  ) %>%
  # Average over trials of each item
  group_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  summarise(
    p_bfs_opt1 = mean(p_bfs_opt1),
    .groups = "drop"
  )
## Joining with `by = join_by(study, sub_id, measurement_id)`
## Joining with `by = join_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id)`
# SR predictions for Day-1 tie problems.
# For every participant x item:
#   - look up the participant's SR value from each option to the goal;
#   - pass both values (scaled by 100) through a softmax with that
#     participant's inverse temperature and lapse rate.
# The resulting p(choose option 1) is averaged over participants per item.
tie_item_analysis_sr <- bind_rows(
  nav_study1_ties, nav_study2_ties, nav_study3_ties
) %>%
  filter(measurement_id == "D1") %>%
  # SR value from option 1 to the goal
  left_join(
    ppc_sr_matrix %>%
      select(
        study, sub_id, measurement_id,
        opt1_id = from,
        endpoint_id = to,
        opt1_sr = sr_value
      ),
    by = c("study", "sub_id", "measurement_id", "endpoint_id", "opt1_id")
  ) %>%
  # SR value from option 2 to the goal, plus the participant's SR parameters
  left_join(
    ppc_sr_matrix %>%
      select(
        study, sub_id, measurement_id,
        opt2_id = from,
        endpoint_id = to,
        opt2_sr = sr_value,
        sr_gamma, softmax_temperature, lapse_rate
      ),
    by = c("study", "sub_id", "measurement_id", "endpoint_id", "opt2_id")
  ) %>%
  rowwise() %>%
  mutate(
    p_sr_opt1 = softmax(
      option_values = c(opt1_sr, opt2_sr) * 100,
      option_chosen = 1,
      temperature = softmax_temperature,
      use_inverse_temperature = TRUE,
      lapse_rate = lapse_rate
    )
  ) %>%
  # Leave rowwise mode explicitly before regrouping by item
  ungroup() %>%
  group_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  summarise(
    # na.rm: presumably drops participant x item rows with no SR lookup match
    p_sr_opt1 = mean(p_sr_opt1, na.rm = TRUE),
    .groups = "drop"
  )
## Joining with `by = join_by(study, sub_id, measurement_id, endpoint_id, opt1_id)`
## Joining with `by = join_by(study, sub_id, measurement_id, endpoint_id, opt2_id)`
# Per-item p(choose option 1) on Day-1 tie problems for humans, SR and BFS.
# Each point is one item; red crossbars are means per shortest-path distance;
# the dashed line is indifference (50%).
plot_ties <- tie_item_analysis_humans %>%
  left_join(
    tie_item_analysis_bfs,
    by = c(
      "shortest_path", "startpoint_id", "endpoint_id", "opt1_id", "opt2_id"
    )
  ) %>%
  left_join(
    tie_item_analysis_sr,
    by = c(
      "shortest_path", "startpoint_id", "endpoint_id", "opt1_id", "opt2_id"
    )
  ) %>%
  # One row per item x agent (columns p_human_opt1 / p_sr_opt1 / p_bfs_opt1)
  pivot_longer(
    starts_with("p_"),
    names_to = "agent",
    values_to = "p_choose_opt1"
  ) %>%
  mutate(
    agent = str_remove_all(agent, "p_|_opt1"),
    agent = case_when(
      agent == "human" ~ "Human",
      agent == "sr" ~ "SR",
      agent == "bfs" ~ "BFS"
    ),
    agent = fct_relevel(agent, "Human", "SR")
  ) %>%
  ggplot(aes(x = shortest_path, y = p_choose_opt1)) +
  theme_custom() +
  facet_wrap(~agent) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_point(alpha = 0.25) +
  stat_summary(geom = "crossbar", fun = mean, color = "red") +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "p(Choose Source A > B)",
    labels = scales::percent, breaks = seq(0, 1, 0.25)
  ) +
  theme(
    panel.grid = element_blank(),
    legend.position = "bottom"
  ) +
  ggtitle("Navigation problems with two correct answers")
## Joining with `by = join_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id)`
## Joining with `by = join_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id)`
plot_ties

# Write the tie-problem figure to disk only when knitting the document
if (knitting) {
  ggsave(
    filename = here("figures", "navigation_heldout_day1.pdf"),
    plot = plot_ties,
    width = 8 / 2, height = 2.5, units = "in", dpi = 300
  )
}

Does navigation improve after rest?

First, some descriptive statistics…

# Mean navigation accuracy per measurement x shortest-path distance, pooled
# over all three studies and reshaped to one column per distance ("dist-N").
bind_rows(nav_study1, nav_study2, nav_study3) %>%
  summarise(
    accuracy = mean(correct),
    .by = c(measurement_id, shortest_path)
  ) %>%
  arrange(measurement_id, shortest_path) %>%
  pivot_wider(
    names_from = shortest_path,
    names_prefix = "dist-",
    values_from = accuracy
  ) %>%
  kbl(
    digits = 2,
    caption = str_c(
      "<center>", "Descriptive: Navigation accuracy", "</center>"
    )
  ) %>%
  kable_styling(bootstrap_options = "responsive")
Descriptive: Navigation accuracy
measurement_id dist-2 dist-3 dist-4
D1 0.80 0.70 0.63
D1b 0.79 0.72 0.68
D2 0.82 0.75 0.71
D2b 0.81 0.72 0.62

Overnight rest

Now we want to see how navigation accuracy changes from Day 1 to Day 2. This analysis pools Studies 2 and 3 (Study 1 was a one-day experiment).

# Day 1 and Day 2 navigation trials from Studies 2 and 3. The study label is
# prefixed to sub_id ("<study> s<id>") so subject identifiers are distinct
# across studies.
nav_day1_to_day2 <- bind_rows(nav_study2, nav_study3) %>%
  filter(measurement_id == "D1" | measurement_id == "D2") %>%
  mutate(sub_id = str_c(study, sub_id, sep = " s"))

# Logistic mixed model of Day 1 -> Day 2 accuracy with shortest-path distance
# releveled so `ref_level` is the reference. Fitting it once per reference
# level puts the D2 effect at each distance on its own main-effect row.
# NOTE(review): with only two studies, the (1 | study) SD is estimated at 0
# in the output shown — consider whether this term is needed.
fit_nav_day1_to_day2 <- function(ref_level) {
  nav_day1_to_day2 %>%
    mutate(shortest_path = fct_relevel(shortest_path, ref_level)) %>%
    glmmTMB(
      correct ~ shortest_path * measurement_id +
        (1 + shortest_path + measurement_id | sub_id) + (1 | study),
      family = binomial,
      data = .
    )
}

stats_nav_day2_dist2 <- fit_nav_day1_to_day2("2")
stats_nav_day2_dist3 <- fit_nav_day1_to_day2("3")
stats_nav_day2_dist4 <- fit_nav_day1_to_day2("4")

# Coefficient table for the three Day 1 -> Day 2 fits (one per reference
# distance), stacked with one row group per reference level.
# Row-group sizes are derived from the tidied output instead of hard-coded
# row ranges. The grouping then stays aligned if the number of fixed- or
# random-effect terms changes (e.g. after editing the model formula).
coefs_nav_day2 <- map_dfr(
  .x = list(
    "dist-2" = stats_nav_day2_dist2,
    "dist-3" = stats_nav_day2_dist3,
    "dist-4" = stats_nav_day2_dist4
  ),
  .f = ~tidy(.x, conf.int = TRUE),
  .id = "ref_cat"
) %>%
  check_significance()

coefs_nav_day2 %>%
  select(-c(ref_cat, effect, component)) %>%
  kbl(
    caption = str_c(
      "<center>", "Navigation accuracy: Day 1 to Day 2", "</center>"
    ),
    digits = 3
  ) %>%
  kable_styling(bootstrap_options = c("responsive")) %>%
  # Named vector of consecutive run lengths, e.g. c("Ref. Cat. dist-2" = 17)
  pack_rows(
    index = with(
      rle(coefs_nav_day2$ref_cat),
      setNames(lengths, str_c("Ref. Cat. ", values))
    )
  )
Navigation accuracy: Day 1 to Day 2
group term estimate std.error statistic p.value conf.low conf.high sig
Ref. Cat. dist-2
NA (Intercept) 1.765 0.114 15.476 0.000 1.541 1.988 ***
NA shortest_path3 -0.452 0.083 -5.452 0.000 -0.615 -0.290 ***
NA shortest_path4 -0.959 0.091 -10.554 0.000 -1.137 -0.781 ***
NA measurement_idD2 0.234 0.080 2.906 0.004 0.076 0.391 **
NA shortest_path3:measurement_idD2 0.030 0.087 0.351 0.725 -0.139 0.200
NA shortest_path4:measurement_idD2 0.196 0.083 2.356 0.018 0.033 0.359
sub_id sd__(Intercept) 1.024 NA NA NA 0.858 1.223
sub_id sd__shortest_path3 0.461 NA NA NA 0.334 0.635
sub_id sd__shortest_path4 0.634 NA NA NA 0.512 0.785
sub_id sd__measurement_idD2 0.483 NA NA NA 0.370 0.631
sub_id cor__(Intercept).shortest_path3 0.182 NA NA NA -0.175 0.481
sub_id cor__(Intercept).shortest_path4 -0.250 NA NA NA -0.799 0.057
sub_id cor__(Intercept).measurement_idD2 0.451 NA NA NA -0.037 0.608
sub_id cor__shortest_path3.shortest_path4 0.808 NA NA NA 0.311 0.873
sub_id cor__shortest_path3.measurement_idD2 0.466 NA NA NA -0.116 0.837
sub_id cor__shortest_path4.measurement_idD2 0.007 NA NA NA -0.520 0.712
study sd__(Intercept) 0.000 NA NA NA 0.000 Inf
Ref. Cat. dist-3
NA (Intercept) 1.312 0.134 9.810 0.000 1.050 1.574 ***
NA shortest_path2 0.452 0.083 5.452 0.000 0.290 0.615 ***
NA shortest_path4 -0.507 0.079 -6.411 0.000 -0.662 -0.352 ***
NA measurement_idD2 0.264 0.088 2.989 0.003 0.091 0.437 **
NA shortest_path2:measurement_idD2 -0.030 0.087 -0.351 0.725 -0.200 0.139
NA shortest_path4:measurement_idD2 0.166 0.090 1.833 0.067 -0.011 0.343 .
sub_id sd__(Intercept) 1.197 NA NA NA 0.997 1.437
sub_id sd__shortest_path2 0.461 NA NA NA 0.334 0.635
sub_id sd__shortest_path4 0.377 NA NA NA 0.266 0.536
sub_id sd__measurement_idD2 0.483 NA NA NA 0.370 0.631
sub_id cor__(Intercept).shortest_path2 -0.541 NA NA NA -0.743 -0.174
sub_id cor__(Intercept).shortest_path4 -0.497 NA NA NA -0.605 0.049
sub_id cor__(Intercept).measurement_idD2 0.565 NA NA NA -0.045 0.827
sub_id cor__shortest_path2.shortest_path4 -0.137 NA NA NA -0.002 0.234
sub_id cor__shortest_path2.measurement_idD2 -0.466 NA NA NA -0.209 0.025
sub_id cor__shortest_path4.measurement_idD2 -0.557 NA NA NA -0.093 0.401
study sd__(Intercept) 0.000 NA NA NA 0.000 Inf
Ref. Cat. dist-4
NA (Intercept) 0.805 0.118 6.827 0.000 0.574 1.037 ***
NA shortest_path2 0.959 0.091 10.554 0.000 0.781 1.137 ***
NA shortest_path3 0.507 0.079 6.411 0.000 0.352 0.662 ***
NA measurement_idD2 0.430 0.083 5.174 0.000 0.267 0.592 ***
NA shortest_path2:measurement_idD2 -0.196 0.083 -2.356 0.018 -0.359 -0.033
NA shortest_path3:measurement_idD2 -0.166 0.090 -1.834 0.067 -0.343 0.011 .
sub_id sd__(Intercept) 1.061 NA NA NA 0.883 1.276
sub_id sd__shortest_path2 0.634 NA NA NA 0.512 0.785
sub_id sd__shortest_path3 0.377 NA NA NA 0.265 0.536
sub_id sd__measurement_idD2 0.483 NA NA NA 0.370 0.631
sub_id cor__(Intercept).shortest_path2 -0.356 NA NA NA -0.587 -0.038
sub_id cor__(Intercept).shortest_path3 0.205 NA NA NA -0.465 0.400
sub_id cor__(Intercept).measurement_idD2 0.439 NA NA NA -0.096 0.450
sub_id cor__shortest_path2.shortest_path3 0.695 NA NA NA 0.448 0.848
sub_id cor__shortest_path2.measurement_idD2 -0.007 NA NA NA -0.118 0.214
sub_id cor__shortest_path3.measurement_idD2 0.557 NA NA NA -0.497 0.628
study sd__(Intercept) 0.000 NA NA NA 0.000 Inf
# Model predictions for every measurement (D1, D2) x distance (2-4) cell from
# the dist-2 reference fit. sub_id and study are NA, presumably so that
# predict_glmmTMB returns population-level (random-effect-free) predictions.
predict_nav_day1_to_day2 <- predict_glmmTMB(
  make_predictions_for = expand_grid(
    measurement_id = c("D1", "D2"),
    shortest_path = factor(2:4),
    sub_id = NA,
    study = NA
  ),
  model_object = stats_nav_day2_dist2
)

# Per-participant accuracy (dots) and model predictions +/- 1 SE (point ranges
# joined by lines) before vs after overnight rest, by shortest-path distance.
plot_nav_day1_to_day2 <- nav_day1_to_day2 %>%
  group_by(sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  ggplot(aes(x = shortest_path, y = accuracy, color = measurement_id)) +
  theme_custom() +
  # Chance level for a two-option choice
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  # One dot per participant x measurement x distance
  geom_dotplot(
    aes(fill = measurement_id),
    binaxis = "y", binwidth = 0.01, stackdir = "center", dotsize = 1,
    position = position_dodge(width = 0.75),
    alpha = 0.5, color = NA,
    show.legend = FALSE
  ) +
  # Model predictions (from predict_nav_day1_to_day2)
  geom_pointrange(
    aes(
      x = shortest_path, y = fit, color = measurement_id,
      ymin = fit - se.fit, ymax = fit + se.fit
    ),
    data = predict_nav_day1_to_day2,
    inherit.aes = FALSE,
    position = position_dodge(width = 0.15),
    linewidth = 1,
    show.legend = FALSE
  ) +
  geom_line(
    aes(
      x = shortest_path, y = fit,
      color = measurement_id, group = measurement_id
    ),
    data = predict_nav_day1_to_day2,
    inherit.aes = FALSE,
    position = position_dodge(width = 0.15),
    linewidth = 1
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", breaks = seq(0, 1, 0.25), labels = scales::percent
  ) +
  scale_color_manual(
    name = "Measurement",
    values = c(D1 = "#fa9fb5", D2 = "#7a0177"),
    labels = c(D1 = "Before overnight rest", D2 = "After overnight rest")
  ) +
  scale_fill_manual(values = c(D1 = "#fa9fb5", D2 = "#7a0177")) +
  coord_cartesian(ylim = c(0.3, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle("Navigation after overnight rest")

plot_nav_day1_to_day2

if (knitting) {
  ggsave(
    filename = here("figures", "navigation_day1_to_day2.pdf"),
    plot = plot_nav_day1_to_day2,
    width = 4, height = 3, units = "in", dpi = 300
  )
}

Awake rest

Is a brief period of awake rest sufficient for improving navigation accuracy?

# Logistic mixed model of accuracy before (D1) vs after (D1b) awake rest in
# Study 3, with shortest-path distance releveled to `ref_level`. There is no
# (1 | study) term because only one study is included.
fit_nav_awake <- function(ref_level) {
  nav_study3 %>%
    filter(measurement_id %in% c("D1", "D1b")) %>%
    mutate(shortest_path = fct_relevel(shortest_path, ref_level)) %>%
    glmmTMB(
      correct ~ shortest_path * measurement_id +
        (1 + shortest_path + measurement_id | sub_id),
      family = binomial,
      data = .
    )
}

stats_nav_awake_dist2 <- fit_nav_awake("2")
stats_nav_awake_dist3 <- fit_nav_awake("3")
stats_nav_awake_dist4 <- fit_nav_awake("4")

# Coefficient table for the three awake-rest fits (one per reference distance),
# stacked with one row group per reference level.
# Row-group sizes come from the tidied output rather than hard-coded row
# ranges, so the grouping stays aligned if the number of terms changes.
coefs_nav_awake <- map_dfr(
  .x = list(
    "dist-2" = stats_nav_awake_dist2,
    "dist-3" = stats_nav_awake_dist3,
    "dist-4" = stats_nav_awake_dist4
  ),
  .f = ~tidy(.x, conf.int = TRUE),
  .id = "ref_cat"
) %>%
  check_significance()

coefs_nav_awake %>%
  select(-c(ref_cat, effect, component)) %>%
  kbl(
    caption = str_c("<center>", "Navigation accuracy: Awake Rest", "</center>"),
    digits = 3
  ) %>%
  kable_styling(bootstrap_options = c("responsive")) %>%
  # Named vector of consecutive run lengths, e.g. c("Ref. Cat. dist-2" = 16)
  pack_rows(
    index = with(
      rle(coefs_nav_awake$ref_cat),
      setNames(lengths, str_c("Ref. Cat. ", values))
    )
  )
Navigation accuracy: Awake Rest
group term estimate std.error statistic p.value conf.low conf.high sig
Ref. Cat. dist-2
NA (Intercept) 1.840 0.193 9.542 0.000 1.462 2.218 ***
NA shortest_path3 -0.424 0.129 -3.290 0.001 -0.677 -0.171 **
NA shortest_path4 -0.936 0.134 -6.972 0.000 -1.199 -0.673 ***
NA measurement_idD1b -0.058 0.122 -0.475 0.635 -0.296 0.181
NA shortest_path3:measurement_idD1b 0.022 0.123 0.175 0.861 -0.220 0.263
NA shortest_path4:measurement_idD1b 0.244 0.119 2.060 0.039 0.012 0.477
sub_id sd__(Intercept) 1.216 NA NA NA 0.947 1.561
sub_id sd__shortest_path3 0.518 NA NA NA 0.352 0.764
sub_id sd__shortest_path4 0.632 NA NA NA 0.465 0.859
sub_id sd__measurement_idD1b 0.544 NA NA NA 0.395 0.751
sub_id cor__(Intercept).shortest_path3 -0.019 NA NA NA -0.462 0.435
sub_id cor__(Intercept).shortest_path4 -0.262 NA NA NA -0.840 0.127
sub_id cor__(Intercept).measurement_idD1b 0.045 NA NA NA -0.264 0.378
sub_id cor__shortest_path3.shortest_path4 0.845 NA NA NA 0.444 0.913
sub_id cor__shortest_path3.measurement_idD1b 0.191 NA NA NA -0.062 0.647
sub_id cor__shortest_path4.measurement_idD1b -0.024 NA NA NA -0.184 0.641
Ref. Cat. dist-3
NA (Intercept) 1.416 0.211 6.714 0.000 1.003 1.829 ***
NA shortest_path2 0.424 0.129 3.290 0.001 0.171 0.677 **
NA shortest_path4 -0.512 0.116 -4.427 0.000 -0.738 -0.285 ***
NA measurement_idD1b -0.036 0.133 -0.273 0.785 -0.296 0.224
NA shortest_path2:measurement_idD1b -0.022 0.123 -0.175 0.861 -0.263 0.220
NA shortest_path4:measurement_idD1b 0.223 0.129 1.722 0.085 -0.031 0.476 .
sub_id sd__(Intercept) 1.313 NA NA NA 1.010 1.705
sub_id sd__shortest_path2 0.518 NA NA NA 0.352 0.764
sub_id sd__shortest_path4 0.338 NA NA NA 0.190 0.603
sub_id sd__measurement_idD1b 0.544 NA NA NA 0.395 0.751
sub_id cor__(Intercept).shortest_path2 -0.377 NA NA NA -0.704 0.175
sub_id cor__(Intercept).shortest_path4 -0.408 NA NA NA -0.681 0.342
sub_id cor__(Intercept).measurement_idD1b 0.117 NA NA NA -0.224 0.497
sub_id cor__shortest_path2.shortest_path4 -0.046 NA NA NA 0.106 0.502
sub_id cor__shortest_path2.measurement_idD1b -0.191 NA NA NA -0.124 0.355
sub_id cor__shortest_path4.measurement_idD1b -0.337 NA NA NA 0.009 0.560
Ref. Cat. dist-4
NA (Intercept) 0.904 0.193 4.684 0.000 0.526 1.283 ***
NA shortest_path2 0.936 0.134 6.971 0.000 0.673 1.199 ***
NA shortest_path3 0.512 0.116 4.427 0.000 0.285 0.738 ***
NA measurement_idD1b 0.187 0.126 1.476 0.140 -0.061 0.434
NA shortest_path2:measurement_idD1b -0.244 0.119 -2.060 0.039 -0.477 -0.012
NA shortest_path3:measurement_idD1b -0.223 0.129 -1.722 0.085 -0.476 0.031 .
sub_id sd__(Intercept) 1.214 NA NA NA 0.935 1.577
sub_id sd__shortest_path2 0.632 NA NA NA 0.465 0.859
sub_id sd__shortest_path3 0.338 NA NA NA 0.190 0.603
sub_id sd__measurement_idD1b 0.544 NA NA NA 0.395 0.751
sub_id cor__(Intercept).shortest_path2 -0.258 NA NA NA -0.602 0.216
sub_id cor__(Intercept).shortest_path3 0.163 NA NA NA -0.621 0.521
sub_id cor__(Intercept).measurement_idD1b 0.033 NA NA NA -0.386 0.307
sub_id cor__shortest_path2.shortest_path3 0.573 NA NA NA 0.343 0.835
sub_id cor__shortest_path2.measurement_idD1b 0.024 NA NA NA -0.041 0.338
sub_id cor__shortest_path3.measurement_idD1b 0.337 NA NA NA -0.003 0.674
# Model predictions for every measurement (D1, D1b) x distance (2-4) cell from
# the dist-2 reference fit. sub_id is NA, presumably for population-level
# predictions; study is not a term in this model.
predict_nav_awake <- predict_glmmTMB(
  make_predictions_for = expand_grid(
    measurement_id = c("D1", "D1b"),
    shortest_path = factor(2:4),
    sub_id = NA,
    study = NA
  ),
  model_object = stats_nav_awake_dist2
)

# Per-participant accuracy (dots) and model predictions +/- 1 SE (point ranges
# joined by lines) before (D1) vs after (D1b) a brief period of awake rest,
# Study 3 only.
plot_nav_awake <- nav_study3 %>%
  filter(measurement_id %in% c("D1", "D1b")) %>%
  group_by(sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  ggplot(aes(x = shortest_path, y = accuracy, color = measurement_id)) +
  theme_custom() +
  # Chance level for a two-option choice
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_dotplot(
    aes(fill = measurement_id),
    binwidth = 0.01,
    binaxis = "y", stackdir = "center",
    position = position_dodge(width = 0.75),
    dotsize = 1, alpha = 0.5, color = NA,
    show.legend = FALSE
  ) +
  geom_pointrange(
    aes(
      x = shortest_path, y = fit,
      ymin = fit - se.fit, ymax = fit + se.fit,
      color = measurement_id
    ),
    data = predict_nav_awake, inherit.aes = FALSE, show.legend = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1
  ) +
  geom_line(
    aes(
      x = shortest_path, y = fit,
      group = measurement_id, color = measurement_id
    ),
    data = predict_nav_awake, inherit.aes = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", labels = scales::percent, breaks = seq(0, 1, 0.25)
  ) +
  scale_color_manual(
    name = "Measurement",
    values = c("D1" = "#fa9fb5", "D1b" = "#2c7fb8"),
    # D1 and D1b both precede overnight rest, so D1 is labelled relative to
    # the awake rest that this figure is about
    labels = c("D1" = "Before awake rest", "D1b" = "After awake rest")
  ) +
  scale_fill_manual(values = c("D1" = "#fa9fb5", "D1b" = "#2c7fb8")) +
  coord_cartesian(ylim = c(0.3, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle("Navigation after awake rest")

plot_nav_awake

if (knitting) {
  ggsave(
    here("figures", str_c("navigation_awake_rest", ".pdf")),
    plot = plot_nav_awake,
    width = 4, height = 3,
    units = "in", dpi = 300
  )
}

SR replay simulation

Before running any simulations, we first want to know how much replay an agent could fit into periods of different lengths. Based on past research measuring neural replay events, we assume that the brain takes about 50 ms to replay a single item from a sequence. This lets us calculate the total number of items that could be replayed in each period (e.g., 1 minute = 60 s / 0.05 s = 1,200 items). To simplify running the simulation, we convert this quantity into the number of “sets” that can be replayed. A single set covers every relationship in the network once in each direction: the network has 17 undirected edges, so each set consists of 34 observations (both A->B and B->A for every edge). One minute of replay therefore corresponds to 1,200 / 34 ≈ 35 sets.

# Replay capacity for several replay durations: the number of items that fit
# at 0.05 s (50 ms) per item, and the equivalent number of 34-observation
# "sets" (17 undirected edges, each replayed in both directions).
tibble(minutes_for_replay = c(1, 5, 15, 30, 60, 120)) %>%
  mutate(n_items = minutes_for_replay / (0.05 / 60)) %>%
  mutate(n_sets = n_items / 34) %>%
  kbl(
    digits = 2,
    caption = str_c("<center>", "SR replay time", "</center>")
  ) %>%
  kable_styling(bootstrap_options = "responsive")
SR replay time
minutes_for_replay n_items n_sets
1 1200 35.29
5 6000 176.47
15 18000 529.41
30 36000 1058.82
60 72000 2117.65
120 144000 4235.29
# Simulate SR learning followed by `minutes` of offline replay.
# Every run includes the first `n_baseline_sets` sets of paired associates
# (the no-replay condition), plus as many whole extra sets as fit into the
# replay period. The replay budget assumes `ms_per_item` ms per replayed item
# and `obs_per_set` observations per set (17 undirected edges x 2 directions).
# With the defaults this gives floor(minutes * 1200 / 34) replay sets, i.e.
# 35, 176 and 2117 sets for 1, 5 and 60 minutes (the floored n_sets above).
# Integer-millisecond arithmetic keeps the set counts exact.
simulate_replay <- function(minutes, n_baseline_sets = 6,
                            ms_per_item = 50, obs_per_set = 34) {
  n_items <- minutes * 60 * 1000 / ms_per_item
  n_replay_sets <- floor(n_items / obs_per_set)

  simulated_paired_associates %>%
    filter(set <= .env$n_replay_sets + .env$n_baseline_sets) %>%
    simulate_sr() %>%
    mutate(replay_time = .env$minutes)
}

sim_replay_0_min <- simulate_replay(0)
sim_replay_1_min <- simulate_replay(1)
sim_replay_5_min <- simulate_replay(5)
sim_replay_60_min <- simulate_replay(60)
# Simulated navigation accuracy of an SR agent after 0, 1, 5 and 60 minutes of
# replay, by shortest-path distance and SR discount (gamma).
plot_sim_sr_replay <- bind_rows(
  join_sr(triallist_nav_learned, sim_replay_0_min),
  join_sr(triallist_nav_learned, sim_replay_1_min),
  join_sr(triallist_nav_learned, sim_replay_5_min),
  join_sr(triallist_nav_learned, sim_replay_60_min)
) %>%
  # Feed values through softmax (SR values scaled by 100, inverse temperature)
  mutate(across(c(opt1_sr, opt2_sr), ~.x * 100)) %>%
  expand_grid(temperature = 1) %>%
  rowwise() %>%
  mutate(
    p_correct = softmax(
      option_values = c(opt1_sr, opt2_sr),
      option_chosen = if_else(correct_choice == opt1_id, 1, 2),
      temperature = temperature,
      use_inverse_temperature = TRUE
    )
  ) %>%
  ungroup() %>%
  # For plotting
  mutate(
    gamma = factor(gamma),
    temperature = str_pad(temperature, width = 2, side = "left"),
    temperature = str_c(unicode_greek["tau"], " = ", temperature),
    # Facet labels are built from the numeric replay time. Only exactly one
    # minute is singular (a "1$" pattern would also catch 11, 21, ...).
    # Facets are ordered numerically: padded-string order would put
    # "120 minutes" before "60 minutes".
    replay_label = case_when(
      replay_time == 0 ~ "No replay",
      replay_time == 1 ~ str_c(
        str_pad(replay_time, width = 2, side = "left"), " minute of replay"
      ),
      TRUE ~ str_c(
        str_pad(replay_time, width = 2, side = "left"), " minutes of replay"
      )
    ),
    replay_time = fct_reorder(replay_label, replay_time)
  ) %>%
  group_by(replay_time, gamma, temperature, shortest_path) %>%
  summarise(p_correct = mean(p_correct), .groups = "drop") %>%
  # Plot
  ggplot(aes(x = shortest_path, y = p_correct, color = gamma)) +
  theme_custom() +
  facet_grid(cols = vars(replay_time), rows = vars(temperature)) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_line(aes(color = gamma, group = gamma), linewidth = 0.8) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  scale_color_viridis_d(
    name = str_c(unicode_greek["gamma"], " = "), option = "turbo", end = 0.9
  ) +
  guides(color = guide_legend(byrow = TRUE, nrow = 1)) +
  coord_cartesian(ylim = c(0.5, 1)) +
  theme(legend.position = "bottom") +
  ggtitle("Simulated effects of SR replay on navigation")
## Joining with `by = join_by(endpoint_id, opt1_id)`
## Warning in left_join(., simulated_sr %>% rename(opt1_sr = sr_value, opt1_id = from, : Detected an unexpected many-to-many relationship between `x` and `y`.
## ℹ Row 1 of `x` matches multiple rows in `y`.
## ℹ Row 45 of `y` matches multiple rows in `x`.
## ℹ If a many-to-many relationship is expected, set `relationship = "many-to-many"` to silence this
##   warning.
## Joining with `by = join_by(endpoint_id, opt2_id, gamma, replay_time)`
## Joining with `by = join_by(endpoint_id, opt1_id)`
## Warning in left_join(., simulated_sr %>% rename(opt1_sr = sr_value, opt1_id = from, : Detected an unexpected many-to-many relationship between `x` and `y`.
## ℹ Row 1 of `x` matches multiple rows in `y`.
## ℹ Row 45 of `y` matches multiple rows in `x`.
## ℹ If a many-to-many relationship is expected, set `relationship = "many-to-many"` to silence this
##   warning.
## Joining with `by = join_by(endpoint_id, opt2_id, gamma, replay_time)`
## Joining with `by = join_by(endpoint_id, opt1_id)`
## Warning in left_join(., simulated_sr %>% rename(opt1_sr = sr_value, opt1_id = from, : Detected an unexpected many-to-many relationship between `x` and `y`.
## ℹ Row 1 of `x` matches multiple rows in `y`.
## ℹ Row 45 of `y` matches multiple rows in `x`.
## ℹ If a many-to-many relationship is expected, set `relationship = "many-to-many"` to silence this
##   warning.
## Joining with `by = join_by(endpoint_id, opt2_id, gamma, replay_time)`
## Joining with `by = join_by(endpoint_id, opt1_id)`
## Warning in left_join(., simulated_sr %>% rename(opt1_sr = sr_value, opt1_id = from, : Detected an unexpected many-to-many relationship between `x` and `y`.
## ℹ Row 1 of `x` matches multiple rows in `y`.
## ℹ Row 45 of `y` matches multiple rows in `x`.
## ℹ If a many-to-many relationship is expected, set `relationship = "many-to-many"` to silence this
##   warning.
## Joining with `by = join_by(endpoint_id, opt2_id, gamma, replay_time)`
plot_sim_sr_replay

# Saved with cairo_pdf; presumably for the Greek-letter (tau/gamma) labels
if (knitting) {
  ggsave(
    filename = here("figures", "simulated_replay.pdf"),
    plot = plot_sim_sr_replay,
    device = cairo_pdf,
    width = 7, height = 2.5, units = "in", dpi = 300
  )
}

PPC before/after rest

# Posterior predictive check for Studies 2 and 3: empirical (human) vs
# SR-predicted accuracy per shortest-path distance, before vs after rest.
# NOTE(review): measurement_id values other than D1/D2 (if any exist in
# ppc_sr_behavior) become NA in the case_when below — confirm only D1/D2 occur.
plot_ppc_day1_to_day2 <- ppc_sr_behavior %>%
  filter(study %in% c("Study 2", "Study 3")) %>%
  mutate(
    measurement_id = fct_relevel(
      case_when(
        measurement_id == "D1" ~ "Before rest",
        measurement_id == "D2" ~ "After rest"
      ),
      "Before rest"
    )
  ) %>%
  pivot_longer(cols = c(empirical, predicted)) %>%
  ggplot(aes(x = shortest_path, y = value)) +
  theme_custom() +
  facet_wrap(~measurement_id) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  # Individual values, jittered and dodged by source (human vs SR)
  geom_point(
    aes(color = name),
    position = position_jitterdodge(
      jitter.width = 0.2, jitter.height = 0, dodge.width = 0.5, seed = 1
    ),
    alpha = 0.05,
    show.legend = FALSE
  ) +
  # Mean per source, dodged to line up with the points
  stat_summary(
    aes(color = name),
    fun = mean, geom = "crossbar",
    position = position_dodge(0.5)
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", breaks = seq(0, 1, 0.25), labels = scales::percent
  ) +
  scale_color_manual(
    name = NULL,
    values = c(empirical = "#fd8d3c", predicted = "#bd0026"),
    labels = c(empirical = "Human", predicted = "SR")
  ) +
  theme(legend.position = "bottom") +
  ggtitle("Posterior predictive check")

plot_ppc_day1_to_day2

if (knitting) {
  ggsave(
    filename = here("figures", "ppc_day2.pdf"),
    plot = plot_ppc_day1_to_day2,
    width = 8 / 3, height = 2.5, units = "in", dpi = 300
  )
}

Relating model parameters to navigation

Does estimated gamma significantly increase after overnight rest?

# Median fitted SR gamma at D1 and D2 (Studies 2 and 3)
params %>%
  filter(
    model == "sr",
    param_name == "sr_gamma",
    study %in% c("Study 2", "Study 3"),
    measurement_id %in% c("D1", "D2")
  ) %>%
  group_by(measurement_id) %>%
  summarise(median_gamma = median(param_value)) %>%
  kbl(
    digits = 3,
    caption = str_c(
      "<center>", "Median SR gamma before/after overnight rest", "</center>"
    )
  ) %>%
  kable_styling(bootstrap_options = "responsive")
Median SR gamma before/after overnight rest
measurement_id median_gamma
D1 0.512
D2 0.659
# One-sided paired Wilcoxon signed-rank test: is each participant's fitted SR
# gamma greater at D2 than at D1? (Studies 2 and 3)
params %>%
  filter(
    model == "sr",
    param_name == "sr_gamma",
    study %in% c("Study 2", "Study 3"),
    measurement_id %in% c("D1", "D2")
  ) %>%
  select(study, sub_id, measurement_id, sr_gamma = param_value) %>%
  # One row per participant with D1 and D2 columns
  pivot_wider(names_from = measurement_id, values_from = sr_gamma) %>%
  with(
    wilcox.test(
      x = D2, y = D1,
      paired = TRUE,
      alternative = "greater",
      conf.int = TRUE
    )
  ) %>%
  tidy() %>%
  kbl(
    digits = 3,
    caption = str_c(
      "<center>", "Increase in SR gamma after overnight rest", "</center>"
    )
  ) %>%
  kable_styling(bootstrap_options = "responsive")
Increase in SR gamma after overnight rest
estimate statistic p.value conf.low conf.high method alternative
0.064 2876 0.023 0.009 Inf Wilcoxon signed rank test with continuity correction greater
# Per-participant fitted SR gamma (points) and its median (crossbar) before vs
# after overnight rest, Studies 2 and 3.
plot_gamma_change <- params %>%
  filter(
    model == "sr",
    param_name == "sr_gamma",
    study %in% c("Study 2", "Study 3"),
    measurement_id %in% c("D1", "D2")
  ) %>%
  select(study, sub_id, measurement_id, sr_gamma = param_value) %>%
  ggplot(aes(x = measurement_id, y = sr_gamma, color = measurement_id)) +
  theme_custom() +
  geom_point(
    alpha = 0.25,
    position = position_jitterdodge(
      jitter.width = 0.25, jitter.height = 0, dodge.width = 0.5, seed = 1
    ),
    show.legend = FALSE
  ) +
  stat_summary(
    fun = median, geom = "crossbar", position = "dodge", show.legend = FALSE
  ) +
  scale_x_discrete(
    name = "Measurement",
    labels = c(D1 = "Before rest", D2 = "After rest")
  ) +
  scale_y_continuous(
    name = str_c("Estimated ", unicode_greek["gamma"]),
    breaks = seq(0, 1, 0.25)
  ) +
  scale_color_manual(
    name = "Measurement",
    values = c(D1 = "#fa9fb5", D2 = "#7a0177"),
    labels = c(D1 = "Before rest", D2 = "After rest")
  ) +
  coord_cartesian(ylim = c(0, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle(
    str_c(
      unicode_greek["Delta"], unicode_greek["gamma"], " after overnight rest"
    )
  )

plot_gamma_change

if (knitting) {
  ggsave(
    filename = here("figures", "gamma_change.pdf"),
    plot = plot_gamma_change,
    device = cairo_pdf,
    width = 8 / 3, height = 2.5, units = "in", dpi = 300
  )
}

Are changes in estimated gamma related to changes in navigation behaviors?

# One-sided Spearman correlation, per shortest-path distance, between each
# participant's change in accuracy (D2 - D1) and change in fitted SR gamma
# (D2 - D1), Studies 2 and 3.
left_join(
  # Change in accuracy per participant x distance
  bind_rows(nav_study2, nav_study3) %>%
    filter(measurement_id %in% c("D1", "D2")) %>%
    group_by(study, sub_id, measurement_id, shortest_path) %>%
    summarise(accuracy = mean(correct), .groups = "drop") %>%
    pivot_wider(names_from = measurement_id, values_from = accuracy) %>%
    mutate(delta_accuracy = D2 - D1) %>%
    select(-c(D1, D2)),
  # Change in fitted SR gamma per participant
  params %>%
    filter(
      study %in% c("Study 2", "Study 3"),
      measurement_id %in% c("D1", "D2"),
      model == "sr",
      param_name == "sr_gamma"
    ) %>%
    select(study, sub_id, measurement_id, sr_gamma = param_value) %>%
    pivot_wider(names_from = measurement_id, values_from = sr_gamma) %>%
    mutate(delta_sr = D2 - D1) %>%
    select(-c(D1, D2)),
  by = c("study", "sub_id")
) %>%
  group_by(shortest_path) %>%
  nest() %>%
  mutate(
    test = map(data, function(d) {
      tidy(
        cor.test(
          d$delta_accuracy, d$delta_sr,
          method = "spearman", exact = FALSE, alternative = "greater"
        )
      )
    })
  ) %>%
  unnest(test) %>%
  ungroup() %>%
  select(-data) %>%
  kbl(
    digits = 3,
    caption = str_c(
      "<center>", "∆ Accuracy ~ ∆ SR gamma (after overnight rest)", "</center>"
    )
  ) %>%
  kable_styling(bootstrap_options = "responsive")
∆ Accuracy ~ ∆ SR gamma (after overnight rest)
shortest_path estimate statistic p.value method alternative
2 -0.180 173977.52 0.960 Spearman’s rank correlation rho greater
3 0.205 117163.09 0.022 Spearman’s rank correlation rho greater
4 0.453 80620.94 0.000 Spearman’s rank correlation rho greater
# Scatter of change in accuracy vs change in fitted SR gamma (both D2 - D1),
# one point per participant x distance, with a per-distance linear fit.
plot_gamma_accuracy_change <- left_join(
  # Change in accuracy per participant x distance
  bind_rows(nav_study2, nav_study3) %>%
    filter(measurement_id %in% c("D1", "D2")) %>%
    group_by(study, sub_id, measurement_id, shortest_path) %>%
    summarise(accuracy = mean(correct), .groups = "drop") %>%
    pivot_wider(names_from = measurement_id, values_from = accuracy) %>%
    mutate(delta_accuracy = D2 - D1) %>%
    select(-c(D1, D2)),
  # Change in fitted SR gamma per participant
  params %>%
    filter(
      study %in% c("Study 2", "Study 3"),
      measurement_id %in% c("D1", "D2"),
      model == "sr",
      param_name == "sr_gamma"
    ) %>%
    select(study, sub_id, measurement_id, sr_gamma = param_value) %>%
    pivot_wider(names_from = measurement_id, values_from = sr_gamma) %>%
    mutate(delta_sr = D2 - D1) %>%
    select(-c(D1, D2)),
  by = c("study", "sub_id")
) %>%
  ggplot(aes(x = delta_sr, y = delta_accuracy, color = shortest_path)) +
  theme_custom() +
  # Zero-change reference lines
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_vline(xintercept = 0, linetype = "dashed") +
  geom_point(alpha = 0.25, show.legend = FALSE) +
  geom_smooth(method = "lm", se = FALSE, linewidth = 1.5) +
  scale_x_continuous(name = str_c("Change in ", unicode_greek["gamma"])) +
  scale_y_continuous(name = "Change in accuracy") +
  scale_color_manual(
    name = "Shortest path distance",
    values = c("#88CCEE", "#CC6677", "#DDCC77")
  ) +
  coord_cartesian(xlim = c(-1, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle(
    str_c(
      unicode_greek["Delta"], "Navigation ~ ",
      unicode_greek["Delta"], unicode_greek["gamma"]
    )
  )

plot_gamma_accuracy_change

if (knitting) {
  # NOTE(review): warnings are suppressed for the default PDF device,
  # presumably because of the Greek-letter labels; the cairo copy below is
  # saved without suppression.
  suppressWarnings(
    ggsave(
      filename = here("figures", "gamma_accuracy_change.pdf"),
      plot = plot_gamma_accuracy_change,
      width = 8 / 3, height = 2.5, units = "in", dpi = 300
    )
  )

  ggsave(
    filename = here("figures", "gamma_accuracy_change_cairo.pdf"),
    plot = plot_gamma_accuracy_change,
    device = cairo_pdf,
    width = 8 / 3, height = 2.5, units = "in", dpi = 300
  )
}
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'

Evidence of cached representation

Here, we try to get a sense of what is being cached. We also look for evidence that caching (as opposed to model-based planning) is the primary driver of the improvement in navigation after overnight rest.

Simulation visualization

butterfly_layout <- create_layout(g, layout = "stress")

# Draw the "cognitive map" implied by the asymptotic SR for three discount
# factors (one facet each): an edge is drawn between two nodes when their
# symmetrised SR value exceeds 0.05, coloured by whether that pair is an
# observed connection in `adjlist` (edge == 1) or only inferred (edge == 0),
# with transparency tracking the SR value.
plot_network_with_sr <- sim_sr_matrix_asymptotic %>%
  # Only three illustrative discount factors are drawn. round() guards against
  # floating-point drift in gamma values built with seq(0.1, 0.9, 0.1).
  # Filtering here (before averaging) is equivalent because gamma is a
  # grouping key below, and avoids summarising rows that would be dropped.
  filter(round(gamma, 1) %in% c(0.1, 0.5, 0.9)) %>%
  # The network is undirected: put every (from, to) pair in canonical order
  # and average the SR values of the two directions.
  mutate(
    from_sorted = if_else(from < to, from, to),
    to_sorted = if_else(from < to, to, from)
  ) %>%
  group_by(gamma, from = from_sorted, to = to_sorted) %>%
  # Drop grouping so later mutate() calls (e.g. factor(edge)) are computed over
  # the whole table rather than per group.
  summarise(sr_value = mean(sr_value), .groups = "drop") %>%
  filter(sr_value > 0.05) %>%
  left_join(adjlist, by = c("from", "to")) %>%
  # Pairs are already ordered, so this only removes self-pairs (from == to).
  filter(from < to) %>%
  mutate(
    edge = factor(edge),
    gamma = str_c(unicode_greek["gamma"], " = ", gamma)
  ) %>%
  tbl_graph(edges = ., directed = FALSE) %>%
  mutate(name = row_number()) %>%
  ggraph("manual", x=butterfly_layout$x, y=butterfly_layout$y) +
  theme_network() +
  facet_edges(~gamma, ncol = 1) +
  geom_edge_link(aes(alpha = sr_value, color = edge)) +
  geom_node_label(aes(label = name)) +
  scale_edge_color_manual(
    name = NULL,
    values = c("0"="red", "1"="black"),
    labels = c("0"="Inferred connections", "1"="Observed connections")
  ) +
  scale_edge_alpha(name = "p(Target | Source)") +
  theme(
    legend.position = "bottom",
    legend.box = "vertical",
    strip.background = element_blank(),
    strip.text = element_text(size = 13)
  ) +
  ggtitle("Cognitive maps predicted by SR")
## `summarise()` has grouped output by 'gamma', 'from'. You can override using the `.groups` argument.
## Joining with `by = join_by(from, to)`
# Show the SR cognitive-map figure; when rendering to HTML, also write it at
# 3.5 x 6 in via the default PDF device (warnings muted) and via cairo_pdf.
plot_network_with_sr

if (knitting) {
  suppressWarnings(
    ggsave(
      filename = here("figures", "network_with_sr.pdf"),
      plot = plot_network_with_sr,
      width = 3.5,
      height = 6,
      units = "in",
      dpi = 300
    )
  )

  ggsave(
    filename = here("figures", "network_with_sr_cairo.pdf"),
    plot = plot_network_with_sr,
    device = cairo_pdf,
    width = 3.5,
    height = 6,
    units = "in",
    dpi = 300
  )
}

Transition reevaluation

Following transition reevaluation, a caching account predicts that people’s navigation should get worse relative to their performance after overnight rest. In contrast, a planning account predicts that people’s navigation should not be greatly impacted by a relatively small set of changes.

# Transition-reevaluation models: navigation accuracy on D1, D2 and D2b (after
# reevaluation), pooled over Studies 2-3. D2b is the reference session, so the
# D1/D2 coefficients compare each earlier session with post-reevaluation
# performance. The same model is fit three times, once per reference distance;
# this only reparameterizes it.
#
# Subject IDs are not guaranteed to be unique across studies (the Day-1
# analysis prefixes them with the study for exactly this reason). Without the
# prefix, the `(... | sub_id)` random effects would pool different people who
# happen to share an ID.
nav_tr_pooled <- bind_rows(nav_study2, nav_study3) %>%
  filter(measurement_id %in% c("D1", "D2", "D2b")) %>%
  mutate(
    sub_id = str_c(study, " s", sub_id),
    measurement_id = fct_relevel(measurement_id, "D2b")
  )

# Fit the pooled reevaluation model with `ref_distance` ("2", "3" or "4") as
# the reference level of shortest_path.
fit_nav_tr <- function(ref_distance) {
  nav_tr_pooled %>%
    mutate(shortest_path = fct_relevel(shortest_path, ref_distance)) %>%
    glmmTMB(
      correct ~ shortest_path * measurement_id +
        (1 + shortest_path + measurement_id | sub_id) + (1 | study),
      family = binomial,
      data = .
    )
}

stats_nav_tr_dist2 <- fit_nav_tr("2")
stats_nav_tr_dist3 <- fit_nav_tr("3")
stats_nav_tr_dist4 <- fit_nav_tr("4")

# Coefficient table for the three reparameterizations of the
# transition-reevaluation model, stacked with one row group per reference
# distance.
nav_tr_coefs <- map_dfr(
  .x = list(
    "dist-2" = stats_nav_tr_dist2,
    "dist-3" = stats_nav_tr_dist3,
    "dist-4" = stats_nav_tr_dist4
  ),
  .f = ~tidy(.x, conf.int = TRUE),
  .id = "ref_cat"
) %>%
  check_significance() %>%
  # The intervals reported for random-effect correlations (cor__ terms) are not
  # usable: in this report's rendered output several of them exclude their own
  # point estimate. Blank them rather than report misleading bounds.
  mutate(
    across(
      c(conf.low, conf.high),
      ~if_else(str_detect(term, "^cor__"), NA_real_, .x)
    )
  )

# Row groups are derived from the data (runs of ref_cat) instead of hard-coded
# row ranges, so they stay correct if the number of parameters changes.
nav_tr_row_groups <- rle(nav_tr_coefs$ref_cat)

nav_tr_coefs %>%
  select(-c(ref_cat, effect, component)) %>%
  kbl(
    caption = str_c(
      "<center>", "Navigation accuracy: Transition Reevaluation", "</center>"
    ),
    digits = 3
  ) %>%
  kable_styling(bootstrap_options = c("responsive")) %>%
  pack_rows(
    index = setNames(
      nav_tr_row_groups$lengths,
      str_c("Ref. Cat. ", nav_tr_row_groups$values)
    )
  )
Navigation accuracy: Transition Reevaluation
group term estimate std.error statistic p.value conf.low conf.high sig
Ref. Cat. dist-2
NA (Intercept) 1.600 0.110 14.558 0.000 1.385 1.815 ***
NA shortest_path3 -0.536 0.075 -7.183 0.000 -0.682 -0.390 ***
NA shortest_path4 -1.029 0.080 -12.837 0.000 -1.187 -0.872 ***
NA measurement_idD1 0.038 0.066 0.584 0.559 -0.091 0.168
NA measurement_idD2 0.181 0.068 2.657 0.008 0.048 0.315 **
NA shortest_path3:measurement_idD1 0.052 0.077 0.672 0.502 -0.099 0.203
NA shortest_path4:measurement_idD1 0.088 0.074 1.183 0.237 -0.058 0.234
NA shortest_path3:measurement_idD2 0.071 0.079 0.902 0.367 -0.083 0.225
NA shortest_path4:measurement_idD2 0.287 0.076 3.765 0.000 0.137 0.436 ***
sub_id sd__(Intercept) 0.728 NA NA NA 0.583 0.910
sub_id sd__shortest_path3 0.369 NA NA NA 0.267 0.509
sub_id sd__shortest_path4 0.437 NA NA NA 0.339 0.564
sub_id sd__measurement_idD1 0.275 NA NA NA 0.193 0.391
sub_id sd__measurement_idD2 0.285 NA NA NA 0.200 0.407
sub_id cor__(Intercept).shortest_path3 0.055 NA NA NA -0.307 0.397
sub_id cor__(Intercept).shortest_path4 -0.353 NA NA NA -0.936 0.071
sub_id cor__(Intercept).measurement_idD1 -0.296 NA NA NA -0.440 0.119
sub_id cor__(Intercept).measurement_idD2 0.414 NA NA NA -0.036 0.488
sub_id cor__shortest_path3.shortest_path4 0.865 NA NA NA 0.194 0.932
sub_id cor__shortest_path3.measurement_idD1 -0.181 NA NA NA -0.216 0.273
sub_id cor__shortest_path3.measurement_idD2 0.171 NA NA NA -0.246 0.510
sub_id cor__shortest_path4.measurement_idD1 -0.103 NA NA NA 0.245 0.333
sub_id cor__shortest_path4.measurement_idD2 -0.044 NA NA NA -0.188 0.433
sub_id cor__measurement_idD1.measurement_idD2 0.356 NA NA NA 0.560 0.808
study sd__(Intercept) 0.028 NA NA NA 0.004 0.181
Ref. Cat. dist-3
NA (Intercept) 1.064 0.124 8.567 0.000 0.821 1.307 ***
NA shortest_path2 0.536 0.075 7.183 0.000 0.390 0.682 ***
NA shortest_path4 -0.493 0.061 -8.036 0.000 -0.614 -0.373 ***
NA measurement_idD1 0.090 0.071 1.277 0.201 -0.048 0.229
NA measurement_idD2 0.252 0.073 3.446 0.001 0.109 0.396 ***
NA shortest_path2:measurement_idD1 -0.052 0.077 -0.672 0.501 -0.203 0.099
NA shortest_path4:measurement_idD1 0.036 0.078 0.459 0.646 -0.118 0.190
NA shortest_path2:measurement_idD2 -0.071 0.079 -0.902 0.367 -0.225 0.083
NA shortest_path4:measurement_idD2 0.216 0.080 2.684 0.007 0.058 0.374 **
sub_id sd__(Intercept) 0.834 NA NA NA 0.664 1.048
sub_id sd__shortest_path2 0.369 NA NA NA 0.267 0.509
sub_id sd__shortest_path4 0.220 NA NA NA 0.137 0.353
sub_id sd__measurement_idD1 0.275 NA NA NA 0.193 0.391
sub_id sd__measurement_idD2 0.285 NA NA NA 0.200 0.407
sub_id cor__(Intercept).shortest_path2 -0.490 NA NA NA -0.717 -0.095
sub_id cor__(Intercept).shortest_path4 -0.675 NA NA NA -0.753 0.183
sub_id cor__(Intercept).measurement_idD1 -0.338 NA NA NA -0.491 0.099
sub_id cor__(Intercept).measurement_idD2 0.437 NA NA NA -0.051 0.528
sub_id cor__shortest_path2.shortest_path4 -0.042 NA NA NA 0.134 0.401
sub_id cor__shortest_path2.measurement_idD1 0.181 NA NA NA 0.180 0.357
sub_id cor__shortest_path2.measurement_idD2 -0.172 NA NA NA -0.180 0.211
sub_id cor__shortest_path4.measurement_idD1 0.099 NA NA NA 0.324 0.614
sub_id cor__shortest_path4.measurement_idD2 -0.376 NA NA NA -0.005 0.537
sub_id cor__measurement_idD1.measurement_idD2 0.356 NA NA NA 0.549 0.793
study sd__(Intercept) 0.028 NA NA NA 0.004 0.181
Ref. Cat. dist-4
NA (Intercept) 0.571 0.106 5.360 0.000 0.362 0.779 ***
NA shortest_path2 1.029 0.080 12.836 0.000 0.872 1.187 ***
NA shortest_path3 0.493 0.061 8.036 0.000 0.373 0.614 ***
NA measurement_idD1 0.126 0.067 1.898 0.058 -0.004 0.257 .
NA measurement_idD2 0.468 0.070 6.722 0.000 0.332 0.605 ***
NA shortest_path2:measurement_idD1 -0.088 0.074 -1.183 0.237 -0.234 0.058
NA shortest_path3:measurement_idD1 -0.036 0.078 -0.459 0.646 -0.190 0.118
NA shortest_path2:measurement_idD2 -0.287 0.076 -3.765 0.000 -0.436 -0.137 ***
NA shortest_path3:measurement_idD2 -0.216 0.080 -2.684 0.007 -0.374 -0.058 **
sub_id sd__(Intercept) 0.705 NA NA NA 0.559 0.889
sub_id sd__shortest_path2 0.437 NA NA NA 0.339 0.564
sub_id sd__shortest_path3 0.220 NA NA NA 0.137 0.353
sub_id sd__measurement_idD1 0.275 NA NA NA 0.193 0.391
sub_id sd__measurement_idD2 0.285 NA NA NA 0.200 0.407
sub_id cor__(Intercept).shortest_path2 -0.256 NA NA NA -0.545 0.121
sub_id cor__(Intercept).shortest_path3 0.487 NA NA NA -0.449 0.629
sub_id cor__(Intercept).measurement_idD1 -0.370 NA NA NA -0.578 0.074
sub_id cor__(Intercept).measurement_idD2 0.400 NA NA NA -0.098 0.476
sub_id cor__shortest_path2.shortest_path3 0.538 NA NA NA 0.141 0.799
sub_id cor__shortest_path2.measurement_idD1 0.103 NA NA NA 0.104 0.288
sub_id cor__shortest_path2.measurement_idD2 0.044 NA NA NA -0.147 0.342
sub_id cor__shortest_path3.measurement_idD1 -0.099 NA NA NA -0.110 0.437
sub_id cor__shortest_path3.measurement_idD2 0.376 NA NA NA -0.487 0.641
sub_id cor__measurement_idD1.measurement_idD2 0.356 NA NA NA 0.407 0.819
study sd__(Intercept) 0.028 NA NA NA 0.004 0.181
# Population-level predicted accuracy (with standard errors) for every
# session x distance combination, taken from the dist-2 parameterization
# (all parameterizations describe the same model).
predict_tr <- predict_glmmTMB(
  make_predictions_for = expand_grid(
    measurement_id = c("D1", "D2", "D2b"),
    shortest_path = factor(2:4),
    sub_id = NA,
    study = NA
  ),
  model_object = stats_nav_tr_dist2
)

# Navigation accuracy before/after overnight rest and after transition
# reevaluation: per-subject raw accuracies (dots) overlaid with model
# predictions from `predict_tr` (points +/- 1 SE, joined by lines).
# NOTE(review): the dots come from Study 3 only, whereas `predict_tr` comes
# from a model fit to Studies 2 and 3 pooled -- confirm this is intended.
plot_nav_tr <- nav_study3 %>%
  filter(measurement_id %in% c("D1", "D2", "D2b")) %>%
  # One accuracy value per subject x session x distance
  group_by(sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  ggplot(aes(x=shortest_path, y=accuracy, color=measurement_id)) +
  theme_custom() +
  # Chance level for a choice between two options
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_dotplot(
    aes(fill = measurement_id),
    binwidth = 0.01,
    binaxis = "y", stackdir = "center",
    position = position_dodge(width = 0.75),
    dotsize = 1, alpha = 0.5, color = NA,
    show.legend = FALSE
  ) +
  # Model predictions use a narrower dodge than the dots so they sit in the
  # centre of each session's dot cloud
  geom_pointrange(
    aes(
      x = shortest_path, y = fit,
      ymin = fit - se.fit, ymax = fit + se.fit,
      color = measurement_id
    ),
    data = predict_tr, inherit.aes = FALSE, show.legend = FALSE,
    position = position_dodge(width = 0.25), linewidth = 1
  ) +
  geom_line(
    aes(
      x = shortest_path, y = fit,
      group = measurement_id, color = measurement_id
    ),
    data = predict_tr, inherit.aes = FALSE,
    position = position_dodge(width = 0.25), linewidth = 1
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", labels = scales::percent, breaks = seq(0, 1, 0.25)
  ) +
  scale_color_manual(
    name = "Measurement",
    values = c("D1"="#fa9fb5", "D2"="#7a0177", "D2b"="#238b45"),
    labels = c(
      "D1"="Before overnight rest",
      "D2"="After overnight rest",
      "D2b"="After reevaluation"
    )
  ) +
  # Same palette as the colour scale, for the dot fills
  scale_fill_manual(
    values = c("D1"="#fa9fb5", "D2"="#7a0177", "D2b"="#238b45")
  ) +
  coord_cartesian(ylim = c(0.3, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle("Navigation after transition reevaluation")

# Show the reevaluation figure; when rendering to HTML, also save a 4 x 4 in
# PDF copy.
plot_nav_tr

if (knitting) {
  ggsave(
    filename = here("figures", "navigation_reevaluation.pdf"),
    plot = plot_nav_tr,
    width = 4,
    height = 4,
    units = "in",
    dpi = 300
  )
}

Session info

For reproducibility.

# Record the R version, platform, locale and package versions used to render
# this report.
sessionInfo()
## R version 4.3.1 (2023-06-16)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Sonoma 14.2
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib 
## LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## time zone: America/New_York
## tzcode source: internal
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] tictoc_1.2          kableExtra_1.3.4    ggraph_2.1.0       
##  [4] tidygraph_1.2.3     broom.mixed_0.2.9.4 glmmTMB_1.1.7      
##  [7] here_1.0.1          lubridate_1.9.2     forcats_1.0.0      
## [10] stringr_1.5.0       dplyr_1.1.2         purrr_1.0.2        
## [13] readr_2.1.4         tidyr_1.3.0         tibble_3.2.1       
## [16] ggplot2_3.4.3       tidyverse_2.0.0    
## 
## loaded via a namespace (and not attached):
##  [1] tidyselect_1.2.0    viridisLite_0.4.2   farver_2.1.1       
##  [4] viridis_0.6.4       fastmap_1.1.1       tweenr_2.0.2       
##  [7] digest_0.6.33       timechange_0.2.0    lifecycle_1.0.3    
## [10] magrittr_2.0.3      compiler_4.3.1      rlang_1.1.1        
## [13] sass_0.4.7          tools_4.3.1         igraph_1.5.1       
## [16] utf8_1.2.3          yaml_2.3.7          knitr_1.43         
## [19] labeling_0.4.2      graphlayouts_1.0.0  bit_4.0.5          
## [22] xml2_1.3.5          withr_2.5.0         numDeriv_2016.8-1.1
## [25] grid_4.3.1          polyclip_1.10-4     fansi_1.0.4        
## [28] colorspace_2.1-0    future_1.33.0       globals_0.16.2     
## [31] scales_1.2.1        MASS_7.3-60         cli_3.6.1          
## [34] crayon_1.5.2        rmarkdown_2.24      ragg_1.2.5         
## [37] generics_0.1.3      rstudioapi_0.15.0   httr_1.4.7         
## [40] tzdb_0.4.0          minqa_1.2.5         cachem_1.0.8       
## [43] ggforce_0.4.1       splines_4.3.1       rvest_1.0.3        
## [46] parallel_4.3.1      vctrs_0.6.3         boot_1.3-28.1      
## [49] webshot_0.5.5       Matrix_1.6-1        jsonlite_1.8.7     
## [52] hms_1.1.3           bit64_4.0.5         ggrepel_0.9.3      
## [55] listenv_0.9.0       systemfonts_1.0.4   jquerylib_0.1.4    
## [58] glue_1.6.2          parallelly_1.36.0   nloptr_2.0.3       
## [61] codetools_0.2-19    stringi_1.7.12      gtable_0.3.3       
## [64] lme4_1.1-34         munsell_0.5.0       furrr_0.3.1        
## [67] pillar_1.9.0        htmltools_0.5.6     R6_2.5.1           
## [70] TMB_1.9.6           textshaping_0.3.6   rprojroot_2.0.3    
## [73] vroom_1.6.3         evaluate_0.21       lattice_0.21-8     
## [76] highr_0.10          backports_1.4.1     broom_1.0.5        
## [79] bslib_0.5.1         Rcpp_1.0.11         svglite_2.1.1      
## [82] gridExtra_2.3       nlme_3.1-162        mgcv_1.8-42        
## [85] xfun_0.40           pkgconfig_2.0.3
---
title: "Analyses for main text"
output:
  html_document:
    code_download: true
    code_folding: hide
    toc: true
    toc_float:
      collapsed: true
---

# Setup

```{r libraries}
library(tidyverse)
library(here)

library(glmmTMB)
library(broom.mixed)

library(tidygraph)
library(ggraph)

library(kableExtra)
library(tictoc)
```

```{r custom-functions}
source(here("code", "utils", "ggplot_themes.R"))
source(here("code", "utils", "modeling_utils.R"))
source(here("code", "utils", "representation_utils.R"))
source(here("code", "utils", "unicode_greek.R"))

# Append model predictions to a grid of predictor values.
#
# make_predictions_for: data frame of predictor values (new data).
# model_object: a fitted glmmTMB model.
# Returns `make_predictions_for` with two extra columns: `fit` and `se.fit`,
# on the response scale. Random effects are ignored (re.form = NA), so these
# are population-level predictions; unseen grouping levels are allowed.
predict_glmmTMB <- function(make_predictions_for, model_object) {
  predictions <- predict(
    object = model_object,
    newdata = make_predictions_for,
    re.form = NA,
    allow.new.levels = TRUE,
    se.fit = TRUE,
    type = "response"
  )

  bind_cols(make_predictions_for, predictions)
}

# Add a `sig` column of conventional significance codes based on `p.value`:
# "***" p < .001, "**" p < .01, "*" p < .05, "." p < .1, and "" otherwise
# (including missing p-values, e.g. for random-effect parameters).
check_significance <- function(tidy_df) {
  tidy_df %>%
    mutate(
      sig = cut(
        p.value,
        breaks = c(-Inf, 0.001, 0.01, 0.05, 0.1, Inf),
        labels = c("***", "**", "*", ".", ""),
        # Left-closed intervals reproduce the strict "<" thresholds
        right = FALSE
      ) %>%
        as.character() %>%
        coalesce("")
    )
}
```

```{r knitting-setup}
# Figures are written to disk only when rendering to HTML; make sure the
# output folder exists in that case.
knitting <- knitr::is_html_output()

if (knitting && !dir.exists(here("figures"))) {
  dir.create(here("figures"))
}
```


# Network visualizations

```{r load-network-data}
# Learned network. The full adjacency list is kept for later joins. The graph
# uses only observed connections (edge == 1), each once as an undirected
# from < to pair.
adjlist <- read_csv(
  here("data", "clean-data", "adjlist_learned.csv"),
  show_col_types = FALSE
)

g <- tbl_graph(
  edges = adjlist %>%
    filter(from < to, edge == 1) %>%
    select(-edge),
  directed = FALSE
)

# Same construction for the network after transition reevaluation
adjlist_reevaluated <- read_csv(
  here("data", "clean-data", "adjlist_reevaluated.csv"),
  show_col_types = FALSE
)

g_reevaluated <- tbl_graph(
  edges = adjlist_reevaluated %>%
    filter(from < to, edge == 1) %>%
    select(-edge),
  directed = FALSE
)
```

```{r plot-networks}
# Plain drawings of the learned and reevaluated networks ("stress" layout),
# with each node labelled by its row number.
plot_network_learned <- ggraph(
  graph = mutate(g, name = row_number()),
  layout = "stress"
) +
  theme_network() +
  geom_edge_link() +
  geom_node_label(aes(label = name))

plot_network_reevaluated <- ggraph(
  graph = mutate(g_reevaluated, name = row_number()),
  layout = "stress"
) +
  theme_network() +
  geom_edge_link() +
  geom_node_label(aes(label = name))

plot_network_learned
plot_network_reevaluated

# When rendering to HTML, save both drawings as 4 x 2 in PDFs
if (knitting) {
  ggsave(
    filename = here("figures", "network_learned.pdf"),
    plot = plot_network_learned,
    width = 4,
    height = 2,
    units = "in",
    dpi = 300
  )

  ggsave(
    filename = here("figures", "network_reevaluated.pdf"),
    plot = plot_network_reevaluated,
    width = 4,
    height = 2,
    units = "in",
    dpi = 300
  )
}
```


# Can humans solve social navigation problems?

In the paper, we start by examining social navigation behavior in a one-day session (Study 1) and in the first session of the two-day studies (Studies 2-3). The procedure for this session is identical across all three studies; for this part of the dataset, Studies 2-3 are exact replications of Study 1.

```{r load-navigation-data}
# Read one study's navigation ("message passing") trials into a common format.
#
# file: CSV file name inside data/clean-data.
# study_label: value stored in the `study` column (e.g. "Study 1").
# measurement_id_expr: expression, evaluated against the raw columns, that
#   yields the session label (e.g. "D1", "D2", "D2b"). Studies encode sessions
#   differently, so each caller supplies its own mapping.
# Keeps trials where the two options are not both correct and the
# shortest-path distance given the options matches the start-to-end distance.
# Returns one row per trial with a fixed set of columns.
read_nav_study <- function(file, study_label, measurement_id_expr) {
  here("data", "clean-data", file) %>%
    read_csv(show_col_types = FALSE) %>%
    filter(
      two_correct_options == FALSE,
      shortest_path_given_opts == shortest_path_given_start_end
    ) %>%
    mutate(
      study = .env$study_label,
      measurement_id = {{ measurement_id_expr }},
      shortest_path = factor(shortest_path_given_opts)
    ) %>%
    select(
      study, sub_id, measurement_id, shortest_path,
      startpoint_id, endpoint_id,
      opt1_id, opt2_id,
      correct_choice, sub_choice,
      correct, rt
    )
}

# Study 1: one-day study; sessions are numbered.
nav_study1 <- read_nav_study(
  "study1_message_passing.csv", "Study 1",
  str_c("D", measurement_id)
)

# Study 2: sessions on the learned network are numbered; the reevaluated
# network is labelled D2b.
nav_study2 <- read_nav_study(
  "study2_message_passing.csv", "Study 2",
  case_when(
    network == "learned" ~ str_c("D", measurement_id),
    network == "reevaluated" ~ "D2b"
  )
)

# Study 3: three learned-network measurements (D1, D1b, D2) plus the
# reevaluated network (D2b).
nav_study3 <- read_nav_study(
  "study3_message_passing.csv", "Study 3",
  case_when(
    network == "reevaluated" ~ "D2b",
    measurement_id == 1 ~ "D1",
    measurement_id == 2 ~ "D1b",
    measurement_id == 3 ~ "D2"
  )
)
```

We'll start off with some descriptive statistics of human behavior. To maximize statistical power, we will pool across studies whenever possible.

```{r descriptive-navigation-day1}
# Day-1 accuracy pooled over all three studies: one row per session, one
# "dist-<k>" column per shortest-path distance.
bind_rows(nav_study1, nav_study2, nav_study3) %>%
  filter(measurement_id == "D1") %>%
  summarise(
    accuracy = mean(correct),
    .by = c(measurement_id, shortest_path)
  ) %>%
  # Sorting fixes the row order and, through pivot_wider(), the column order
  arrange(measurement_id, shortest_path) %>%
  pivot_wider(
    names_from = shortest_path,
    values_from = accuracy,
    names_prefix = "dist-"
  ) %>%
  kbl(
    caption = str_c("<center>", "Descriptive: Navigation accuracy", "</center>"),
    digits = 2
  ) %>%
  kable_styling(bootstrap_options = c("responsive"))
```

And now the inferential statistical tests. Because we're interested in whether navigation accuracy differs from chance at each distance, we'll estimate the same statistical model multiple times, changing the reference category each time. With a logit link (`family = binomial`), each model's intercept is the log-odds of a correct response at the reference distance, so testing the intercept against 0 is a test of accuracy against chance (50%). Note that this only reparameterizes the model, such that the *same* variance is accounted for by different parameters; it does *not* change the total amount of variance accounted for.

```{r stats-day1}
# Day-1 navigation accuracy, pooled over Studies 1-3.
nav_day1 <- bind_rows(nav_study1, nav_study2, nav_study3) %>%
  filter(measurement_id == "D1") %>%
  # Give every subject a distinct identifier
  mutate(sub_id = str_c(study, " s", sub_id))

# Fit the Day-1 accuracy model with `ref_distance` ("2", "3" or "4") as the
# reference level of shortest_path. Each fit is a reparameterization of the
# same model, so its intercept tests accuracy at that distance against chance.
fit_nav_day1 <- function(ref_distance) {
  nav_day1 %>%
    mutate(shortest_path = fct_relevel(shortest_path, ref_distance)) %>%
    glmmTMB(
      correct ~ shortest_path + (1 + shortest_path | sub_id) + (1 | study),
      family = binomial,
      data = .
    )
}

stats_nav_day1_dist2 <- fit_nav_day1("2")
stats_nav_day1_dist3 <- fit_nav_day1("3")
stats_nav_day1_dist4 <- fit_nav_day1("4")

# Coefficient table for the three reparameterizations of the Day-1 model,
# stacked with one row group per reference distance.
nav_day1_coefs <- map_dfr(
  .x = list(
    "dist-2" = stats_nav_day1_dist2,
    "dist-3" = stats_nav_day1_dist3,
    "dist-4" = stats_nav_day1_dist4
  ),
  .f = ~tidy(.x, conf.int = TRUE),
  .id = "ref_cat"
) %>%
  check_significance() %>%
  # Intervals reported for random-effect correlations (cor__ terms) are not
  # reliable for these multi-term random-effect structures. In the rendered
  # transition-reevaluation table of this report, several of them exclude
  # their own point estimate. Blank them rather than report misleading bounds.
  mutate(
    across(
      c(conf.low, conf.high),
      ~if_else(str_detect(term, "^cor__"), NA_real_, .x)
    )
  )

# Row groups come from runs of ref_cat instead of hard-coded row ranges, so
# they stay correct if the number of parameters per model changes.
nav_day1_row_groups <- rle(nav_day1_coefs$ref_cat)

nav_day1_coefs %>%
  select(-c(ref_cat, effect, component)) %>%
  kbl(
    caption = str_c("<center>", "Navigation accuracy: Day 1", "</center>"),
    digits = 3
  ) %>%
  kable_styling(bootstrap_options = c("responsive")) %>%
  pack_rows(
    index = setNames(
      nav_day1_row_groups$lengths,
      str_c("Ref. Cat. ", nav_day1_row_groups$values)
    )
  )
```

We'll plot out the raw data, plus model predictions...

```{r plot-day1}
# Population-level predicted Day-1 accuracy (with standard errors) at each
# distance, from the dist-2 parameterization.
predict_nav_day1 <- predict_glmmTMB(
  make_predictions_for = expand_grid(
    measurement_id = "D1",
    shortest_path = factor(2:4),
    sub_id = NA,
    study = NA
  ),
  model_object = stats_nav_day1_dist2
)

# Day-1 navigation: per-subject raw accuracy at each distance (dots) with
# model predictions from `predict_nav_day1` (points +/- 1 SE, joined by a
# line). The dashed line marks chance for a two-option choice.
# NOTE(review): the dots use Studies 2-3 only, while the model pools
# Studies 1-3 -- confirm this is intended.
plot_nav_day1 <- bind_rows(nav_study2, nav_study3) %>%
  filter(measurement_id == "D1") %>%
  # Subject IDs can repeat across studies (the stats chunk gives every subject
  # a distinct identifier for this reason). Grouping by study as well keeps
  # different people from being averaged into a single dot.
  group_by(study, sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  ggplot(aes(x=shortest_path, y=accuracy)) +
  theme_custom() +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_dotplot(
    binwidth = 0.01,
    binaxis = "y", stackdir = "center",
    position = position_dodge(width = 0.75),
    dotsize = 1, alpha = 0.5, color = NA,
    show.legend = FALSE
  ) +
  geom_pointrange(
    aes(x = shortest_path, y = fit, ymin = fit - se.fit, ymax = fit + se.fit),
    data = predict_nav_day1, inherit.aes = FALSE, show.legend = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1
  ) +
  geom_line(
    aes(x = shortest_path, y = fit, group = measurement_id),
    data = predict_nav_day1, inherit.aes = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", labels = scales::percent, breaks = seq(0, 1, 0.1)
  ) +
  coord_cartesian(ylim = c(0.4, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle("Human social navigation")

# Show the Day-1 figure; when rendering to HTML, also save an 8/3 x 3 in PDF.
plot_nav_day1

if (knitting) {
  ggsave(
    filename = here("figures", "navigation_day1.pdf"),
    plot = plot_nav_day1,
    width = 8/3,
    height = 3,
    units = "in",
    dpi = 300
  )
}
```


# Computational models of social navigation

We'd like to have some mechanistic insights about how people solve navigation problems. To do this, we'll look at two candidate models of navigation: breadth-first search (BFS) and the Successor Representation (SR).

## Simulations

### BFS simulation

```{r load-bfs-data}
# Simulated BFS agent on the learned network. The same trial filter as the
# human data is applied, and only the columns used downstream are kept.
sim_bfs <- read_csv(
  here("data", "bfs-sims", "bfs_sims_learned.csv"),
  show_col_types = FALSE
) %>%
  filter(
    two_correct_options == FALSE,
    shortest_path_given_opts == shortest_path_given_start_end
  ) %>%
  mutate(shortest_path = factor(shortest_path_given_opts)) %>%
  select(
    shortest_path,
    startpoint_id, endpoint_id,
    opt1_id, opt2_id,
    bfs_choice, bfs_correct_choice,
    bfs_n_visits_total
  )
```

In our implementation, we model the BFS agent having some "threshold" for searching through the network. This can be thought of as a "willingness to spend time/effort performing a search" threshold. Once that threshold is exceeded, it becomes increasingly likely that the agent gives up and chooses randomly.

To see what threshold values might be informative to look at, we'll look at the average number of "searches" that an agent must perform to make a (non-random) decision.

```{r avg-bfs-visits}
# Mean number of BFS node visits ("searches") needed to decide, per distance.
sim_bfs %>%
  summarise(
    avg_n_searches = mean(bfs_n_visits_total),
    .by = shortest_path
  ) %>%
  # `.by` keeps first-appearance order; sort to match a grouped summary
  arrange(shortest_path) %>%
  kbl(
    caption = str_c("<center>", "Average searches in online-BFS", "</center>"),
    digits = 2
  ) %>%
  kable_styling(bootstrap_options = c("responsive"))
```

And now we'll plot the model predictions...

```{r plot-bfs-sim}
# One row per unique trial (distance, start, end, two options): the simulated
# BFS agent's mean accuracy and mean number of node visits on that trial.
bfs_avg_accuracy <- summarise(
  group_by(
    sim_bfs,
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
  ),
  bfs_accuracy = mean(bfs_correct_choice),
  bfs_visits = mean(bfs_n_visits_total),
  .groups = "drop"
)

# Simulated BFS accuracy by distance for a range of search thresholds
# (2, 4, ..., 12), plus a reference agent that always completes its search.
# When the agent gives up it guesses, with accuracy 1/2.
plot_sim_bfs <- bfs_avg_accuracy %>%
  expand_grid(search_threshold = seq(2, 12, 2)) %>%
  rowwise() %>%
  mutate(
    # Note: p(BFS) is 1-p(give up)
    # softmax() is a project helper; assumes it returns the probability of
    # `option_chosen` (here the threshold, compared against the trial's mean
    # number of BFS visits) -- TODO confirm in modeling_utils.R
    p_bfs = softmax(
      option_values = c(search_threshold, bfs_visits),
      option_chosen = 1,
      temperature = 1
    )
  ) %>%
  ungroup() %>%
  # Weigh the model predictions according to their likelihood
  mutate(
    p_give_up = 1 - p_bfs,
    bfs_threshold_accuracy = (p_bfs * bfs_accuracy) + (p_give_up * (1/2))
  ) %>%
  # Format for plotting
  # (left-pad to width 2 so " 2", " 4", ... sort before "10", "12")
  mutate(
    search_threshold = str_pad(search_threshold, width = 2, side = "left")
  ) %>%
  # Reference curve: an agent that never gives up
  bind_rows(
    bfs_avg_accuracy %>%
      mutate(
        search_threshold = "Never gives up",
        bfs_threshold_accuracy = bfs_accuracy
      )
  ) %>%
  # Now plot
  ggplot(
    aes(
      x=shortest_path, y=bfs_threshold_accuracy,
      color=search_threshold, group=search_threshold
    )
  ) +
  theme_custom() +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  # Average over trials at each distance
  stat_summary(geom = "line", fun = mean, linewidth = 1) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  scale_color_viridis_d(
    name = "Search threshold", option = "magma",
    begin = 0.1, end = 0.9, direction = -1
  ) +
  guides(color = guide_legend(byrow = TRUE, nrow = 1)) +
  coord_cartesian(ylim = c(0.5, 1)) +
  theme(
    panel.grid = element_blank(),
    legend.position = "bottom"
  ) +
  ggtitle("Simulated BFS navigation")

# Show the simulated-BFS figure; when rendering to HTML, also save an
# 8/3 x 3 in PDF.
plot_sim_bfs

if (knitting) {
  ggsave(
    filename = here("figures", "simulated_bfs.pdf"),
    plot = plot_sim_bfs,
    width = 8/3,
    height = 3,
    units = "in",
    dpi = 300
  )
}
```

### SR navigation simulation

```{r get-triallist}
# The navigation trial list, taken from one Study 1 participant (sub_id 1)
# and sorted by start then end point, for use in model simulations.
triallist_nav_learned <- nav_study1 %>%
  filter(sub_id == 1) %>%
  arrange(startpoint_id, endpoint_id) %>%
  select(
    startpoint_id, endpoint_id,
    opt1_id, opt2_id,
    correct_choice,
    shortest_path
  )
```

We'll create some simulated learning observations.

```{r simulate-observations}
# Fixed seed derived from a phrase, so the shuffled observation sets are
# reproducible. The order of the random calls below must not change.
set.seed(sum(utf8ToInt("Jenny and me was like peas and carrots")))

# 5000 simulated learning "sets": each contains every row of `adjlist` with
# edge == 1 (an observed connection) exactly once, in an independently
# shuffled order (slice_sample(prop = 1) within each set).
simulated_paired_associates <- adjlist %>%
  filter(edge == 1) %>%
  select(from, to) %>%
  expand_grid(set = 1:5000, .) %>%
  group_by(set) %>%
  slice_sample(prop = 1) %>%
  ungroup()
```

And now we'll simulate an "asymptotic" SR and how it performs in the navigation task.

```{r sim-sr-asymptotic}
# Build an SR from simulated observations for each discount factor in a grid.
#
# simulated_observations: learning data passed (one copy per gamma) to the
#   project helper build_rep_sr() as `learning_data`.
# gammas: discount factors to simulate (default 0.1, 0.2, ..., 0.9).
# alpha: value passed to build_rep_sr() as `this_alpha` (default 0.1);
#   presumably the SR learning rate -- confirm in representation_utils.R.
# Returns the rows produced by build_rep_sr() for every gamma, stacked, with
# a `gamma` column identifying each one.
simulate_sr <- function(simulated_observations,
                        gammas = seq(0.1, 0.9, 0.1),
                        alpha = 0.1) {
  simulated_observations %>%
    expand_grid(gamma = gammas) %>%
    group_by(gamma) %>%
    nest() %>%
    mutate(
      sr = map(
        .x = data,
        # `gamma` is this group's value. `.env$alpha` refers to the function
        # argument and cannot be shadowed by a same-named data column.
        .f = ~build_rep_sr(
          learning_data = .x, this_alpha = .env$alpha, this_gamma = gamma
        )
      )
    ) %>%
    unnest(sr) %>%
    ungroup() %>%
    select(-data)
}

# Attach SR values for both options of each navigation trial.
#
# navigation_triallist: trials with opt1_id, opt2_id and endpoint_id.
# simulated_sr: SR entries with from, to and sr_value (plus gamma).
# opt1_sr / opt2_sr are the SR values from each option to the trial's end
# point. Joins are on all shared columns, so the second lookup also matches on
# the gamma introduced by the first.
join_sr <- function(navigation_triallist, simulated_sr) {
  opt1_lookup <- rename(
    simulated_sr,
    opt1_sr = sr_value, opt1_id = from, endpoint_id = to
  )
  opt2_lookup <- rename(
    simulated_sr,
    opt2_sr = sr_value, opt2_id = from, endpoint_id = to
  )

  navigation_triallist %>%
    left_join(opt1_lookup) %>%
    left_join(opt2_lookup)
}

sim_sr_matrix_asymptotic <- simulate_sr(simulated_paired_associates)

# Probability that an SR agent picks the correct option on each trial, for
# every gamma: both options' SR values are scaled by 100 and passed through
# the project softmax() helper (inverse-temperature form, temperature 1).
sim_sr_behavior_asymptotic <- triallist_nav_learned %>%
  join_sr(sim_sr_matrix_asymptotic) %>%
  mutate(
    opt1_sr = opt1_sr * 100,
    opt2_sr = opt2_sr * 100
  ) %>%
  expand_grid(temperature = 1) %>%
  rowwise() %>%
  mutate(
    p_correct = softmax(
      option_values = c(opt1_sr, opt2_sr),
      option_chosen = if_else(correct_choice == opt1_id, 1, 2),
      temperature = temperature,
      use_inverse_temperature = TRUE
    )
  ) %>%
  ungroup()
```

```{r plot-sr-sim-asymptotic}
# Simulated SR navigation accuracy by distance: one line per gamma, one facet
# row per softmax temperature, averaged over trials.
plot_sim_sr_asymptotic <- sim_sr_behavior_asymptotic %>%
  mutate(
    gamma = factor(gamma),
    # Left-pad temperatures to width 2 so facet labels sort numerically
    temperature = str_pad(temperature, width = 2, side = "left"),
    temperature = str_c(unicode_greek["tau"], " = ", temperature)
  ) %>%
  group_by(gamma, temperature, shortest_path) %>%
  summarise(p_correct = mean(p_correct), .groups = "drop") %>%
  ggplot(aes(x=shortest_path, y=p_correct, color=gamma)) +
  theme_custom() +
  facet_grid(rows = vars(temperature)) +
  # Chance level for a two-option choice
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_line(aes(color = gamma, group = gamma), linewidth = 0.8) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  scale_color_viridis_d(
    name = str_c(unicode_greek["gamma"], " = "), option = "turbo", end = 0.9
  ) +
  guides(color = guide_legend(byrow = TRUE, nrow = 1)) +
  coord_cartesian(ylim = c(0.5, 1)) +
  theme(legend.position = "bottom") +
  ggtitle("Simulated SR navigation")

# Show the simulated-SR figure. When rendering to HTML, save two 8/3 x 3 in
# copies: one via ggsave's default PDF device (warnings muted), and one via
# cairo_pdf with the legend wrapped onto two rows.
plot_sim_sr_asymptotic

if (knitting) {
  suppressWarnings(
    ggsave(
      here("figures", str_c("simulated_sr", ".pdf")),
      plot = plot_sim_sr_asymptotic,
      width = 8/3, height = 3,
      # No trailing comma after the last argument: it would pass an empty
      # argument through ggsave()'s `...` on to the graphics device.
      units = "in", dpi = 300
    )
  )
  
  ggsave(
    here("figures", str_c("simulated_sr_cairo", ".pdf")),
    plot = plot_sim_sr_asymptotic +
      guides(color = guide_legend(byrow = TRUE, nrow = 2)),
    width = 8/3, height = 3,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
}
```

## Model comparison

To do model comparison, we'll need to pull in the estimated parameters/likelihoods from the model-fitting.

```{r load-params}
# Set to TRUE to rebuild data/param-fits/clean_param_fits.csv from the
# per-fit files in data/param-fits. When FALSE, the cached CSV is read.
clean_params_from_raw <- FALSE

if (clean_params_from_raw) {
  params <- here("data", "param-fits") %>%
    # Per-fit files are identified by "study<digit>_D<digit>" (optionally "b")
    fs::dir_ls(regexp = "study[[:digit:]]_D[[:digit:]]b?") %>%
    map_dfr(
      # best_optim_run() is a sourced project helper; presumably it keeps the
      # best of several optimisation runs -- confirm in modeling_utils.R
      .f = ~read_csv(.x, show_col_types = FALSE) %>%
        best_optim_run("dataframe"),
      .id = "filename"
    ) %>%
    # Study, session, subject and model are all parsed from the file name
    mutate(
      study = str_extract(filename, "study[[:digit:]]"),
      study = str_remove(study, "study"),
      study = str_c("Study ", study),
      measurement_id = str_extract(filename, "_D[[:digit:]]b?"),
      measurement_id = str_remove(measurement_id, "_"),
      sub_id = str_extract(filename, "sub-[[:digit:]]+"),
      sub_id = str_remove(sub_id, "sub-"),
      sub_id = as.integer(sub_id),
      # Unrecognised model tags yield NA
      model = case_when(
        str_detect(filename, "_hybrid-") ~ "hybrid",
        str_detect(filename, "_sr-") ~ "sr",
        str_detect(filename, "_bfs-") ~ "bfs"
      )
    ) %>%
    # One row per fitted parameter; the fit's objective value is repeated on
    # each row and stored as neg_loglik
    select(
      study, sub_id, measurement_id,
      model,
      param_name, param_value = param_value_human_readable,
      neg_loglik = optim_value
    ) %>%
    arrange(study, sub_id, measurement_id, model)
  
  params %>%
    write_csv(here("data", "param-fits", "clean_param_fits.csv"))
} else {
  params <- here("data", "param-fits", "clean_param_fits.csv") %>%
    read_csv(show_col_types = FALSE)
}
```

```{r pxp-model-selection}
source(here("code", "utils", "bayesian_model_selection.R"))

# Bayesian model selection separately for each study x session. Input is one
# row per subject with one log-likelihood column per model. Assumes the
# sourced bayesian_model_selection() returns rows with `model` and `pxp`
# (protected exceedance probability) columns -- TODO confirm.
pxp_results <- params %>%
  select(study, sub_id, measurement_id, model, neg_loglik) %>%
  # Each fit has one row per parameter with the same neg_loglik; keep one
  distinct() %>%
  mutate(log_lik = -neg_loglik) %>%
  select(-neg_loglik) %>%
  pivot_wider(names_from = model, values_from = log_lik) %>%
  select(-sub_id) %>%
  group_by(study, measurement_id) %>%
  nest() %>%
  mutate(
    test = map(
      .x = data,
      .f = ~bayesian_model_selection(.x)
    )
  ) %>%
  unnest(test) %>%
  ungroup() %>%
  select(-data)

# PXP per model, by session, one panel per study
pxp_results %>%
  ggplot(aes(x=measurement_id, y=pxp, color=model)) +
  theme_custom() +
  facet_wrap(~study, scales = "free_x") +
  geom_point(
    position = position_dodge(width = 0.75)
  )

# Same values as a wide table (one pxp_<model> column per model)
pxp_results %>%
  select(study, measurement_id, model, pxp) %>%
  mutate(pxp = as.numeric(pxp)) %>%
  pivot_wider(
    names_from = model,
    values_from = pxp,
    names_prefix = "pxp_"
  ) %>%
  kbl(
    caption = str_c("<center>", "PXP results", "</center>"),
    digits = 3
  ) %>%
  kable_styling(bootstrap_options = c("responsive"))
```

## Posterior predictive check

We'll start by computing each model's predicted accuracy for every subject, given that subject's estimated parameters.

```{r ppc-bfs}
# Posterior predictive check for BFS. For each day-1/day-2 trial, combine the
# subject's fitted BFS parameters (search_threshold, lapse_rate) with the
# simulated BFS agent's item-level results (bfs_accuracy, bfs_visits) to get
# the probability that BFS answers correctly. Then average predicted and
# empirical accuracy per subject x measurement x shortest-path distance.
ppc_bfs <- bind_rows(nav_study1, nav_study2, nav_study3) %>%
  filter(measurement_id %in% c("D1", "D2")) %>%
  left_join(
    params %>%
      filter(model == "bfs") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, search_threshold, lapse_rate),
    by = c("study", "sub_id", "measurement_id")
  ) %>%
  # NOTE(review): bfs_accuracy is assumed to be p(BFS picks the correct
  # Source) for this trial's item -- confirm against bfs_avg_accuracy.
  left_join(bfs_avg_accuracy) %>%
  # What's the probability of *completing* BFS-online all the way through?
  rowwise() %>%
  mutate(
    p_complete_bfs = softmax(
      option_values = c(search_threshold, bfs_visits),
      option_chosen = 1,
      temperature = 1
    )
  ) %>%
  ungroup() %>%
  # Weigh BFS predictions accordingly
  mutate(
    p_give_up = 1 - p_complete_bfs,
    # Predicted accuracy: a completed search is correct with probability
    # bfs_accuracy, while giving up is a guess between the two Sources.
    # This is a prediction, so it must not depend on the subject's own
    # choice. Using p(BFS agrees with the subject) here would be the fitting
    # likelihood, not a predicted accuracy.
    predicted_correct = (p_complete_bfs * bfs_accuracy) + (p_give_up * 1/2),
    ### Add lapse rate
    #   Dividing by 2 is because there are two options to choose from
    #   Therefore, when lapse rate = 1, this becomes chance = 1/2
    predicted_correct = predicted_correct * (1 - lapse_rate) + (lapse_rate / 2)
  ) %>%
  # Average over trials
  group_by(study, sub_id, measurement_id, shortest_path) %>%
  summarise(
    empirical = mean(correct),
    predicted = mean(predicted_correct),
    .groups = "drop"
  )
```

```{r ppc-sr}
# Training data for the simulated SR learner: paired-associate sets 1-100.
# It is identical for every subject, so build it once rather than once per
# subject inside map().
ppc_sr_learning_data <- simulated_paired_associates %>%
  filter(set %in% 1:100)

# For each subject x measurement, simulate an SR representation with that
# subject's fitted sr_gamma. The learning rate is fixed at 0.1 and updates are
# bidirectional. The subject's fitted SR parameters (one column per param_name)
# are then attached to every row of the simulated output.
ppc_sr_matrix <- params %>%
  filter(model == "sr") %>%
  select(
    study, sub_id, measurement_id, name = param_name, value = param_value
  ) %>%
  pivot_wider() %>%
  group_by(study, sub_id, measurement_id) %>%
  nest() %>%
  mutate(
    sim_sr = map(
      .x = data,
      .f = ~build_rep_sr(
        learning_data = ppc_sr_learning_data,
        this_alpha = 0.1,
        this_gamma = .x$sr_gamma,
        bidirectional = TRUE
      )
    )
  ) %>%
  unnest(sim_sr) %>%
  unnest(data) %>%
  ungroup()

# Posterior predictive check for the SR. For each day-1/day-2 trial, a softmax
# over the two Sources' SR values toward the Target (scaled by 100) uses the
# subject's fitted softmax_temperature (as an inverse temperature) and
# lapse_rate. The result is the predicted probability of choosing the correct
# Source, averaged per subject x measurement x shortest-path distance.
ppc_sr_behavior <- bind_rows(nav_study1, nav_study2, nav_study3) %>%
  filter(measurement_id %in% c("D1", "D2")) %>%
  left_join(
    ppc_sr_matrix %>%
      rename(opt1_id = from, endpoint_id = to, opt1_sr = sr_value)
  ) %>%
  left_join(
    ppc_sr_matrix %>%
      rename(opt2_id = from, endpoint_id = to, opt2_sr = sr_value)
  ) %>%
  rowwise() %>%
  mutate(
    predicted_correct = softmax(
      option_values = c(opt1_sr, opt2_sr) * 100,
      option_chosen = if_else(correct_choice == opt1_id, 1, 2),
      temperature = softmax_temperature,
      use_inverse_temperature = TRUE,
      lapse_rate = lapse_rate
    )
  ) %>%
  ungroup() %>%
  # Fix trials when the softmax becomes undefined.
  # A NaN here comes from exp() overflowing (Inf / Inf). In that limit the
  # softmax deterministically picks the higher-valued Source, with an even
  # split on an exact tie. The fallback must come from the model's values,
  # not from whether the subject happened to be correct; the latter would
  # leak the data into the PPC.
  # NOTE(review): assumes softmax() applies the lapse as
  # p * (1 - lapse_rate) + lapse_rate / 2, as the BFS PPC does explicitly.
  mutate(
    correct_sr = if_else(correct_choice == opt1_id, opt1_sr, opt2_sr),
    other_sr = if_else(correct_choice == opt1_id, opt2_sr, opt1_sr),
    p_greedy_correct = case_when(
      correct_sr > other_sr ~ 1,
      correct_sr < other_sr ~ 0,
      TRUE ~ 0.5
    ),
    predicted_correct = if_else(
      is.nan(predicted_correct),
      p_greedy_correct * (1 - lapse_rate) + lapse_rate / 2,
      predicted_correct
    )
  ) %>%
  # Average over trials
  group_by(study, sub_id, measurement_id, shortest_path) %>%
  summarise(
    empirical = mean(correct),
    predicted = mean(predicted_correct),
    .groups = "drop"
  )
```

Now that we have both PPCs, we can plot them side-by-side.

```{r plot-ppc-day1}
# Day-1 posterior predictive check: per-subject accuracy by distance for
# humans, BFS, and the SR, shown as faint points plus a crossbar at each
# group's mean.
plot_ppc_day1 <- ppc_bfs %>%
  rename(predicted_bfs = predicted) %>%
  # Join on the summary keys only. Both tables compute `empirical` as
  # mean(correct) over the same trials, so it is dropped from the SR side
  # rather than used as a floating-point join key.
  left_join(
    ppc_sr_behavior %>%
      select(-empirical) %>%
      rename(predicted_sr = predicted),
    by = c("study", "sub_id", "measurement_id", "shortest_path")
  ) %>%
  pivot_longer(
    c(empirical, predicted_bfs, predicted_sr),
    names_to = "agent", values_to = "accuracy"
  ) %>%
  filter(measurement_id == "D1") %>%
  ggplot(aes(x=shortest_path, y=accuracy)) +
  theme_custom() +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  # Points and crossbars share one dodge width (0.75) so that each agent's
  # mean sits on top of its own cloud of points.
  geom_point(
    aes(color = agent),
    alpha = 0.05,
    position = position_jitterdodge(
      jitter.width = 0.2, jitter.height = 0, dodge.width = 0.75, seed = 1
    ),
    show.legend = FALSE
  ) +
  stat_summary(
    aes(color = agent), geom = "crossbar", fun = mean,
    position = position_dodge(0.75)
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", labels = scales::percent, breaks = seq(0, 1, 0.25)
  ) +
  scale_color_manual(
    name = NULL,
    values = c(
      "empirical"="#fd8d3c",
      "predicted_bfs"="#8c2d04",
      "predicted_sr"="#bd0026"
    ),
    labels = c("empirical"="Human", "predicted_bfs"="BFS", "predicted_sr"="SR")
  ) +
  theme(legend.position = "bottom") +
  ggtitle("Posterior predictive check")

plot_ppc_day1

if (knitting) {
  ggsave(
    here("figures", str_c("ppc_day1", ".pdf")),
    plot = plot_ppc_day1,
    width = 8/2, height = 2.5,
    units = "in", dpi = 300
  )
}
```

## Held-out trials

There were some trials where both Sources had equivalent shortest path distances from the Target. We would expect that BFS would be largely indifferent between the two Sources, but it is possible that humans and/or the SR would make other predictions.

```{r load-tie-data}
# Read one study's message-passing trials and keep only the "tie" trials:
# rows flagged two_correct_options whose shortest path through the offered
# Sources equals the overall Start -> Target shortest path.
read_tie_trials <- function(study_number) {
  here(
    "data", "clean-data",
    str_c("study", study_number, "_message_passing.csv")
  ) %>%
    read_csv(show_col_types = FALSE) %>%
    filter(
      two_correct_options == TRUE,
      shortest_path_given_opts == shortest_path_given_start_end
    )
}

# Keep the columns used by the tie analyses and restrict to the D1/D2 tests.
# Expects `study` and a recoded `measurement_id` ("D1", "D2", ...) to exist.
finalize_tie_trials <- function(tie_trials) {
  tie_trials %>%
    mutate(shortest_path = factor(shortest_path_given_opts)) %>%
    select(
      study, sub_id, measurement_id, shortest_path,
      startpoint_id, endpoint_id,
      opt1_id, opt2_id,
      correct_choice, sub_choice,
      correct, rt
    ) %>%
    filter(measurement_id %in% c("D1", "D2"))
}

nav_study1_ties <- read_tie_trials(1) %>%
  mutate(
    study = "Study 1",
    measurement_id = str_c("D", measurement_id)
  ) %>%
  finalize_tie_trials()

nav_study2_ties <- read_tie_trials(2) %>%
  mutate(
    study = "Study 2",
    measurement_id = case_when(
      network == "learned" ~ str_c("D", measurement_id),
      network == "reevaluated" ~ "D2b"
    )
  ) %>%
  finalize_tie_trials()

nav_study3_ties <- read_tie_trials(3) %>%
  mutate(
    study = "Study 3",
    # Study 3's second session on day 1 is labelled D1b
    measurement_id = case_when(
      network == "reevaluated" ~ "D2b",
      measurement_id == 1 ~ "D1",
      measurement_id == 2 ~ "D1b",
      measurement_id == 3 ~ "D2"
    )
  ) %>%
  finalize_tie_trials()
```

```{r ties-human}
# For each tie item (distance, Start, Target, and the two Sources), the share
# of day-1 human choices that went to Source A (opt1).
tie_item_analysis_humans <- list(
  nav_study1_ties, nav_study2_ties, nav_study3_ties
) %>%
  bind_rows() %>%
  filter(measurement_id == "D1") %>%
  group_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  summarise(
    p_human_opt1 = mean(sub_choice == opt1_id),
    .groups = "drop"
  )
```

```{r ties-bfs}
# Item-level summary of the simulated BFS agent on tie trials: how often BFS
# picked Source A (opt1), and its mean total number of node visits.
bfs_avg_accuracy_ties <- read_csv(
  here("data", "bfs-sims", "bfs_sims_learned.csv"),
  show_col_types = FALSE
) %>%
  filter(
    two_correct_options == TRUE,
    shortest_path_given_opts == shortest_path_given_start_end
  ) %>%
  mutate(shortest_path = factor(shortest_path_given_opts)) %>%
  select(
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    bfs_choice, bfs_correct_choice, bfs_n_visits_total
  ) %>%
  group_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  summarise(
    p_bfs_opt1 = mean(bfs_choice == opt1_id),
    bfs_visits = mean(bfs_n_visits_total),
    .groups = "drop"
  )

# Item-level BFS prediction for the day-1 tie trials: the probability that
# BFS picks Source A (opt1). It mixes the simulated BFS rate of choosing opt1
# (when the search completes) with a coin flip (when it gives up), then adds
# lapses. It is averaged over each item's trials so it is directly comparable
# with p_human_opt1 and p_sr_opt1.
tie_item_analysis_bfs <- bind_rows(
  nav_study1_ties, nav_study2_ties, nav_study3_ties
) %>%
  filter(measurement_id == "D1") %>%
  left_join(
    params %>%
      filter(model == "bfs") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, search_threshold, lapse_rate),
    by = c("study", "sub_id", "measurement_id")
  ) %>%
  left_join(
    bfs_avg_accuracy_ties,
    by = c(
      "shortest_path", "startpoint_id", "endpoint_id", "opt1_id", "opt2_id"
    )
  ) %>%
  # What's the probability of *completing* BFS-online all the way through?
  rowwise() %>%
  mutate(
    p_complete_bfs = softmax(
      option_values = c(search_threshold, bfs_visits),
      option_chosen = 1,
      temperature = 1
    )
  ) %>%
  ungroup() %>%
  # Weigh BFS predictions accordingly
  mutate(
    p_give_up = 1 - p_complete_bfs,
    # p(choose opt1) must not depend on which Source the subject chose:
    # a completed search picks opt1 at the simulated rate p_bfs_opt1, and
    # giving up is a coin flip between the two Sources.
    p_bfs_opt1 = (p_complete_bfs * p_bfs_opt1) + (p_give_up * 1/2),
    ### Add lapse rate
    #   Dividing by 2 is because there are two options to choose from
    #   Therefore, when lapse rate = 1, this becomes chance = 1/2
    p_bfs_opt1 = p_bfs_opt1 * (1 - lapse_rate) + (lapse_rate / 2)
  ) %>%
  # Average over trials
  group_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  summarise(
    p_bfs_opt1 = mean(p_bfs_opt1),
    .groups = "drop"
  )
```

```{r ties-sr}
# Item-level SR prediction for the day-1 tie trials: the probability of
# choosing Source A (opt1) from a softmax over both Sources' SR values toward
# the Target, using each subject's fitted softmax_temperature (as an inverse
# temperature) and lapse_rate. It is averaged over each item's trials.
tie_item_analysis_sr <- bind_rows(
  nav_study1_ties, nav_study2_ties, nav_study3_ties
) %>%
  filter(measurement_id == "D1") %>%
  left_join(
    ppc_sr_matrix %>%
      select(
        study, sub_id, measurement_id,
        opt1_id = from,
        endpoint_id = to,
        opt1_sr = sr_value
      ),
    by = c("study", "sub_id", "measurement_id", "opt1_id", "endpoint_id")
  ) %>%
  left_join(
    ppc_sr_matrix %>%
      select(
        study, sub_id, measurement_id,
        opt2_id = from,
        endpoint_id = to,
        opt2_sr = sr_value,
        sr_gamma, softmax_temperature, lapse_rate
      ),
    by = c("study", "sub_id", "measurement_id", "opt2_id", "endpoint_id")
  ) %>%
  rowwise() %>%
  mutate(
    p_sr_opt1 = softmax(
      option_values = c(opt1_sr, opt2_sr) * 100,
      option_chosen = 1,
      temperature = softmax_temperature,
      use_inverse_temperature = TRUE,
      lapse_rate = lapse_rate
    )
  ) %>%
  ungroup() %>%
  # A NaN comes from exp() overflowing (Inf / Inf). In that limit the softmax
  # deterministically picks the higher-valued Source (even split on a tie).
  # Use that limit instead of silently dropping the trial from the mean.
  # NOTE(review): assumes softmax() applies the lapse as
  # p * (1 - lapse_rate) + lapse_rate / 2 -- confirm.
  mutate(
    p_greedy_opt1 = case_when(
      opt1_sr > opt2_sr ~ 1,
      opt1_sr < opt2_sr ~ 0,
      TRUE ~ 0.5
    ),
    p_sr_opt1 = if_else(
      is.nan(p_sr_opt1),
      p_greedy_opt1 * (1 - lapse_rate) + lapse_rate / 2,
      p_sr_opt1
    )
  ) %>%
  group_by(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  summarise(
    p_sr_opt1 = mean(p_sr_opt1, na.rm = TRUE),
    .groups = "drop"
  )
```

```{r plot-ties}
# Day-1 tie items: p(choose Source A) for humans, the SR, and BFS, one facet
# per agent, with a red crossbar at each distance's mean over items.
plot_ties <- tie_item_analysis_humans %>%
  left_join(tie_item_analysis_bfs) %>%
  left_join(tie_item_analysis_sr) %>%
  mutate(item = factor(row_number())) %>%
  pivot_longer(
    starts_with("p_"),
    names_to = "agent",
    values_to = "p_choose_opt1"
  ) %>%
  mutate(
    # "p_human_opt1" -> "human" -> "Human", etc.
    agent = case_match(
      str_remove_all(agent, "p_|_opt1"),
      "human" ~ "Human",
      "sr" ~ "SR",
      "bfs" ~ "BFS"
    ),
    agent = fct_relevel(agent, "Human", "SR")
  ) %>%
  ggplot(aes(x = shortest_path, y = p_choose_opt1)) +
  theme_custom() +
  facet_wrap(~agent) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_point(alpha = 0.25) +
  stat_summary(fun = mean, geom = "crossbar", color = "red") +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "p(Choose Source A > B)",
    breaks = seq(0, 1, 0.25),
    labels = scales::percent
  ) +
  theme(
    panel.grid = element_blank(),
    legend.position = "bottom"
  ) +
  ggtitle("Navigation problems with two correct answers")

plot_ties

if (knitting) {
  ggsave(
    here("figures", str_c("navigation_heldout_day1", ".pdf")),
    plot = plot_ties,
    width = 8/2, height = 2.5,
    units = "in", dpi = 300
  )
}
```


# Does navigation improve after rest?

First, some descriptives...

```{r descriptive-navigation-all}
# Mean accuracy per measurement x shortest-path distance, pooled over all
# studies and subjects; each distance becomes a "dist-<k>" column.
bind_rows(nav_study1, nav_study2, nav_study3) %>%
  summarise(
    accuracy = mean(correct),
    .by = c(measurement_id, shortest_path)
  ) %>%
  arrange(measurement_id, shortest_path) %>%
  pivot_wider(
    names_from = shortest_path,
    names_prefix = "dist-",
    values_from = accuracy
  ) %>%
  kbl(
    digits = 2,
    caption = str_c(
      "<center>", "Descriptive: Navigation accuracy", "</center>"
    )
  ) %>%
  kable_styling(bootstrap_options = c("responsive"))
```

## Overnight rest

Now we want to see how navigation accuracy changes from Day 1 to Day 2. This analysis pools Studies 2 and 3 (Study 1 was a one-day experiment).

```{r stats-day1-to-day2}
# Day-1 and day-2 navigation trials from Studies 2 and 3.
nav_day1_to_day2 <- bind_rows(nav_study2, nav_study3) %>%
  filter(measurement_id %in% c("D1", "D2")) %>%
  # Give every subject a distinct identifier
  # (sub_id values are presumably reused across studies, hence the prefix)
  mutate(sub_id = str_c(study, " s", sub_id))

# Binomial (logistic) mixed model: correct ~ shortest_path * measurement_id,
# with by-subject random intercepts and shortest_path/measurement_id slopes,
# plus a random intercept per study. The three fits below are the same model
# with the shortest_path reference level set to 2, 3 and 4. The
# measurement_id main effect is therefore the D1 -> D2 change (log-odds) at
# each reference distance.
# NOTE(review): these fits pass `data = .` from the pipe, and
# stats_nav_day2_dist2 is later passed to predict_glmmTMB(); confirm that
# prediction still works before restructuring these calls.
stats_nav_day2_dist2 <- nav_day1_to_day2 %>%
  mutate(shortest_path = fct_relevel(shortest_path, "2")) %>%
  glmmTMB(
    correct ~ shortest_path * measurement_id +
      (1 + shortest_path + measurement_id | sub_id) + (1 | study),
    family = binomial,
    data = .
  )

# Same model, reference distance = 3
stats_nav_day2_dist3 <- nav_day1_to_day2 %>%
  mutate(shortest_path = fct_relevel(shortest_path, "3")) %>%
  glmmTMB(
    correct ~ shortest_path * measurement_id +
      (1 + shortest_path + measurement_id | sub_id) + (1 | study),
    family = binomial,
    data = .
  )

# Same model, reference distance = 4
stats_nav_day2_dist4 <- nav_day1_to_day2 %>%
  mutate(shortest_path = fct_relevel(shortest_path, "4")) %>%
  glmmTMB(
    correct ~ shortest_path * measurement_id +
      (1 + shortest_path + measurement_id | sub_id) + (1 | study),
    family = binomial,
    data = .
  )

# Coefficient tables of the three reference-distance refits, stacked.
# ref_cat records which refit each row came from.
tidy_nav_day1_to_day2 <- map_dfr(
  .x = list(
    "dist-2" = stats_nav_day2_dist2,
    "dist-3" = stats_nav_day2_dist3,
    "dist-4" = stats_nav_day2_dist4
  ),
  .f = ~tidy(.x, conf.int = TRUE),
  .id = "ref_cat"
) %>%
  check_significance()

# Row groups for pack_rows(), derived from the stacked table in display order
# instead of hard-coded row ranges. They stay correct if the number of model
# terms changes (e.g. a different random-effects structure).
nav_day1_to_day2_row_groups <- rle(tidy_nav_day1_to_day2$ref_cat)

tidy_nav_day1_to_day2 %>%
  select(-c(ref_cat, effect, component)) %>%
  kbl(
    caption = str_c(
      "<center>", "Navigation accuracy: Day 1 to Day 2", "</center>"
    ),
    digits = 3
  ) %>%
  kable_styling(bootstrap_options = c("responsive")) %>%
  pack_rows(
    index = setNames(
      nav_day1_to_day2_row_groups$lengths,
      str_c("Ref. Cat. ", nav_day1_to_day2_row_groups$values)
    )
  )
```

```{r plot-day1-to-day2}
# Model predictions for every day x distance cell. sub_id and study are set
# to NA, presumably to get population-level (random-effects-free) estimates.
predict_nav_day1_to_day2 <- predict_glmmTMB(
  make_predictions_for = expand_grid(
    measurement_id = c("D1", "D2"),
    shortest_path = factor(2:4),
    sub_id = NA, study = NA
  ),
  model_object = stats_nav_day2_dist2
)

# Per-subject accuracy (dot stacks) with the model estimate +/- 1 SE overlaid,
# before vs after overnight rest.
plot_nav_day1_to_day2 <- nav_day1_to_day2 %>%
  group_by(sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  ggplot(aes(x = shortest_path, y = accuracy, color = measurement_id)) +
  theme_custom() +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_dotplot(
    aes(fill = measurement_id),
    binaxis = "y", binwidth = 0.01, stackdir = "center", dotsize = 1,
    position = position_dodge(width = 0.75),
    color = NA, alpha = 0.5,
    show.legend = FALSE
  ) +
  geom_pointrange(
    aes(
      x = shortest_path, y = fit, color = measurement_id,
      ymin = fit - se.fit, ymax = fit + se.fit
    ),
    data = predict_nav_day1_to_day2, inherit.aes = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1,
    show.legend = FALSE
  ) +
  geom_line(
    aes(
      x = shortest_path, y = fit,
      color = measurement_id, group = measurement_id
    ),
    data = predict_nav_day1_to_day2, inherit.aes = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", breaks = seq(0, 1, 0.25), labels = scales::percent
  ) +
  scale_color_manual(
    name = "Measurement",
    values = c("D1"="#fa9fb5", "D2"="#7a0177"),
    labels = c("D1"="Before overnight rest", "D2"="After overnight rest")
  ) +
  scale_fill_manual(values = c("D1"="#fa9fb5", "D2"="#7a0177")) +
  coord_cartesian(ylim = c(0.3, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle("Navigation after overnight rest")

plot_nav_day1_to_day2

if (knitting) {
  ggsave(
    here("figures", str_c("navigation_day1_to_day2", ".pdf")),
    plot = plot_nav_day1_to_day2,
    width = 4, height = 3,
    units = "in", dpi = 300
  )
}
```

## Awake rest

Is a brief period of awake rest sufficient for improving navigation accuracy?

```{r stats-awake-rest}
# Awake rest (Study 3 only): binomial mixed model of
# correct ~ shortest_path * measurement_id on the D1 and D1b tests, with
# by-subject random intercepts and shortest_path/measurement_id slopes. The
# three fits differ only in the shortest_path reference level (2, 3, 4), so
# the measurement_id main effect is the D1 -> D1b change (log-odds) at each
# reference distance.
# NOTE(review): these fits pass `data = .` from the pipe, and
# stats_nav_awake_dist2 is later passed to predict_glmmTMB(); confirm that
# prediction still works before restructuring these calls.
stats_nav_awake_dist2 <- nav_study3 %>%
  filter(measurement_id %in% c("D1", "D1b")) %>%
  mutate(shortest_path = fct_relevel(shortest_path, "2")) %>%
  glmmTMB(
    correct ~ shortest_path * measurement_id +
      (1 + shortest_path + measurement_id | sub_id),
    family = binomial,
    data = .
  )

# Same model, reference distance = 3
stats_nav_awake_dist3 <- nav_study3 %>%
  filter(measurement_id %in% c("D1", "D1b")) %>%
  mutate(shortest_path = fct_relevel(shortest_path, "3")) %>%
  glmmTMB(
    correct ~ shortest_path * measurement_id +
      (1 + shortest_path + measurement_id | sub_id),
    family = binomial,
    data = .
  )

# Same model, reference distance = 4
stats_nav_awake_dist4 <- nav_study3 %>%
  filter(measurement_id %in% c("D1", "D1b")) %>%
  mutate(shortest_path = fct_relevel(shortest_path, "4")) %>%
  glmmTMB(
    correct ~ shortest_path * measurement_id +
      (1 + shortest_path + measurement_id | sub_id),
    family = binomial,
    data = .
  )

# Coefficient tables of the three awake-rest refits, stacked.
# ref_cat records which refit each row came from.
tidy_nav_awake <- map_dfr(
  .x = list(
    "dist-2" = stats_nav_awake_dist2,
    "dist-3" = stats_nav_awake_dist3,
    "dist-4" = stats_nav_awake_dist4
  ),
  .f = ~tidy(.x, conf.int = TRUE),
  .id = "ref_cat"
) %>%
  check_significance()

# Row groups for pack_rows(), derived from the stacked table in display order
# instead of hard-coded row ranges (1-16, 17-32, 33-48).
nav_awake_row_groups <- rle(tidy_nav_awake$ref_cat)

tidy_nav_awake %>%
  select(-c(ref_cat, effect, component)) %>%
  kbl(
    caption = str_c("<center>", "Navigation accuracy: Awake Rest", "</center>"),
    digits = 3
  ) %>%
  kable_styling(bootstrap_options = c("responsive")) %>%
  pack_rows(
    index = setNames(
      nav_awake_row_groups$lengths,
      str_c("Ref. Cat. ", nav_awake_row_groups$values)
    )
  )
```

```{r plot-awake-rest}
# Model predictions for D1 vs D1b at each distance. sub_id and study are set
# to NA, presumably to get population-level estimates.
predict_nav_awake <- expand_grid(
  measurement_id = c("D1", "D1b"),
  shortest_path = factor(2:4),
  sub_id = NA, study = NA
) %>%
  predict_glmmTMB(stats_nav_awake_dist2)

# Per-subject accuracy (dot stacks) with the model estimate +/- 1 SE overlaid,
# before vs after a period of awake rest (Study 3).
plot_nav_awake <- nav_study3 %>%
  filter(measurement_id %in% c("D1", "D1b")) %>%
  group_by(sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  ggplot(aes(x=shortest_path, y=accuracy, color=measurement_id)) +
  theme_custom() +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_dotplot(
    aes(fill = measurement_id),
    binwidth = 0.01,
    binaxis = "y", stackdir = "center",
    position = position_dodge(width = 0.75),
    dotsize = 1, alpha = 0.5, color = NA,
    show.legend = FALSE
  ) +
  geom_pointrange(
    aes(
      x = shortest_path, y = fit,
      ymin = fit - se.fit, ymax = fit + se.fit,
      color = measurement_id
    ),
    data = predict_nav_awake, inherit.aes = FALSE, show.legend = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1
  ) +
  geom_line(
    aes(
      x = shortest_path, y = fit,
      group = measurement_id, color = measurement_id
    ),
    data = predict_nav_awake, inherit.aes = FALSE,
    position = position_dodge(width = 0.15), linewidth = 1
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", labels = scales::percent, breaks = seq(0, 1, 0.25)
  ) +
  # This figure contrasts D1 with D1b (awake rest), so the D1 label names the
  # awake-rest contrast rather than overnight rest.
  scale_color_manual(
    name = "Measurement",
    values = c("D1"="#fa9fb5", "D1b"="#2c7fb8"),
    labels = c("D1"="Before awake rest", "D1b"="After awake rest")
  ) +
  scale_fill_manual(values = c("D1"="#fa9fb5", "D1b"="#2c7fb8")) +
  coord_cartesian(ylim = c(0.3, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle("Navigation after awake rest")

plot_nav_awake

if (knitting) {
  ggsave(
    here("figures", str_c("navigation_awake_rest", ".pdf")),
    plot = plot_nav_awake,
    width = 4, height = 3,
    units = "in", dpi = 300
  )
}
```

## SR replay simulation

Before running any simulations, we first want to know how much replay an agent can fit into different periods of time. Based on past research measuring neural replay events, we'll assume that it takes about 50 ms for the brain to replay a single item from a sequence, i.e., 1,200 items per minute. This lets us calculate the total number of items that could be replayed in each epoch. To simplify the simulation, we convert this quantity into the number of "sets" that can be replayed, where a single set contains every relationship in the network once in each direction. The network has 17 undirected edges, so each set consists of 34 observations (A->B and B->A for every edge). For example, one minute of replay corresponds to 1,200 / 34 ≈ 35 sets.

```{r avail-sr-replay-time}
# How many items, and how many 34-observation sets, fit into each replay
# window at 50 ms per replayed item.
tibble(minutes_for_replay = c(1, 5, 15, 30, 60, 120)) %>%
  mutate(
    # minutes / minutes-per-item (50 ms = 0.05 / 60 min)
    n_items = minutes_for_replay / (0.05 / 60),
    # one set = 17 undirected edges x 2 directions = 34 observations
    n_sets = n_items / 34
  ) %>%
  kbl(
    digits = 2,
    caption = str_c("<center>", "SR replay time", "</center>")
  ) %>%
  kable_styling(bootstrap_options = "responsive")
```

```{r simulate-sr-replay}
# Simulate SR learning with increasing amounts of replay. Every condition
# starts from the same 6 baseline learning sets. Replay then adds as many
# whole sets as fit into the replay window, assuming 50 ms per replayed item
# and 34 items per set (17 undirected edges, each replayed in both
# directions).
sr_replay_baseline_sets <- 6

# Whole sets replayable in `replay_minutes`:
# minutes -> items (60 s / 0.05 s) -> sets (34 items each).
sr_replay_n_sets <- function(replay_minutes) {
  floor(replay_minutes * 60 / 0.05 / 34)
}

# Run the SR simulation on the baseline sets plus the replayed sets, tagging
# the output with the replay duration in minutes.
run_sr_replay_sim <- function(replay_minutes) {
  n_sets <- sr_replay_baseline_sets + sr_replay_n_sets(replay_minutes)
  simulated_paired_associates %>%
    filter(set <= n_sets) %>%
    simulate_sr() %>%
    mutate(replay_time = replay_minutes)
}

sim_replay_0_min <- run_sr_replay_sim(0)   # 6 sets
sim_replay_1_min <- run_sr_replay_sim(1)   # 6 + 35 sets
sim_replay_5_min <- run_sr_replay_sim(5)   # 6 + 176 sets
sim_replay_60_min <- run_sr_replay_sim(60) # 6 + 2117 sets
```

```{r plot-sr-sim-replay}
# Simulated SR navigation accuracy after 0, 1, 5 and 60 minutes of replay.
# Each Source's SR value (x100) goes through a softmax with inverse
# temperature 1, and the result is averaged per gamma x distance.
plot_sim_sr_replay <- bind_rows(
  join_sr(triallist_nav_learned, sim_replay_0_min),
  join_sr(triallist_nav_learned, sim_replay_1_min),
  join_sr(triallist_nav_learned, sim_replay_5_min),
  join_sr(triallist_nav_learned, sim_replay_60_min)
) %>%
  # Feed values through softmax
  mutate(across(c(opt1_sr, opt2_sr), ~.x * 100)) %>%
  expand_grid(temperature = 1) %>%
  rowwise() %>%
  mutate(
    p_correct = softmax(
      option_values = c(opt1_sr, opt2_sr),
      option_chosen = if_else(correct_choice == opt1_id, 1, 2),
      temperature = temperature,
      use_inverse_temperature = TRUE
    )
  ) %>%
  ungroup() %>%
  # For plotting
  mutate(
    gamma = factor(gamma),
    temperature = str_pad(temperature, width = 2, side = "left"),
    temperature = str_c(unicode_greek["tau"], " = ", temperature),
    # Label from the numeric replay time: "minute" only for exactly 1
    # (a regex on the trailing digit would also match 11, 21, ...). Facets
    # are then ordered numerically rather than by label text.
    replay_time = fct_reorder(
      case_when(
        replay_time == 0 ~ "No replay",
        replay_time == 1 ~ str_c(
          str_pad(replay_time, width = 2, side = "left"), " minute of replay"
        ),
        TRUE ~ str_c(
          str_pad(replay_time, width = 2, side = "left"), " minutes of replay"
        )
      ),
      replay_time
    )
  ) %>%
  group_by(replay_time, gamma, temperature, shortest_path) %>%
  summarise(p_correct = mean(p_correct), .groups = "drop") %>%
  # Plot
  ggplot(aes(x=shortest_path, y=p_correct, color=gamma)) +
  theme_custom() +
  facet_grid(cols = vars(replay_time), rows = vars(temperature)) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_line(aes(color = gamma, group = gamma), linewidth = 0.8) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  scale_color_viridis_d(
    name = str_c(unicode_greek["gamma"], " = "), option = "turbo", end = 0.9
  ) +
  guides(color = guide_legend(byrow = TRUE, nrow = 1)) +
  coord_cartesian(ylim = c(0.5, 1)) +
  theme(legend.position = "bottom") +
  ggtitle("Simulated effects of SR replay on navigation")

plot_sim_sr_replay

if (knitting) {
  ggsave(
    here("figures", str_c("simulated_replay", ".pdf")),
    plot = plot_sim_sr_replay,
    width = 7, height = 2.5,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
}
```

## PPC before/after rest

```{r plot-ppc-day1-day2}
# SR posterior predictive check before vs after overnight rest (Studies 2-3):
# per-subject empirical and SR-predicted accuracy, with crossbars at means.
plot_ppc_day1_to_day2 <- ppc_sr_behavior %>%
  filter(study %in% c("Study 2", "Study 3")) %>%
  mutate(
    measurement_id = case_match(
      measurement_id,
      "D1" ~ "Before rest",
      "D2" ~ "After rest"
    ),
    measurement_id = fct_relevel(measurement_id, "Before rest")
  ) %>%
  pivot_longer(c(empirical, predicted)) %>%
  ggplot(aes(x = shortest_path, y = value)) +
  theme_custom() +
  facet_wrap(~measurement_id) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_point(
    aes(color = name),
    position = position_jitterdodge(
      jitter.width = 0.2, jitter.height = 0, dodge.width = 0.5, seed = 1
    ),
    alpha = 0.05,
    show.legend = FALSE
  ) +
  stat_summary(
    aes(color = name),
    fun = mean, geom = "crossbar",
    position = position_dodge(0.5)
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", breaks = seq(0, 1, 0.25), labels = scales::percent
  ) +
  scale_color_manual(
    name = NULL,
    values = c("empirical"="#fd8d3c", "predicted"="#bd0026"),
    labels = c("empirical"="Human", "predicted"="SR")
  ) +
  theme(legend.position = "bottom") +
  ggtitle("Posterior predictive check")

plot_ppc_day1_to_day2

if (knitting) {
  ggsave(
    here("figures", str_c("ppc_day2", ".pdf")),
    plot = plot_ppc_day1_to_day2,
    width = 8/3, height = 2.5,
    units = "in", dpi = 300
  )
}
```


## Relating model parameters to navigation

Does estimated gamma significantly increase after overnight rest?

```{r stats-gamma-increase}
# Median fitted SR gamma on D1 vs D2 (Studies 2 and 3).
params %>%
  filter(
    model == "sr",
    param_name == "sr_gamma",
    study %in% c("Study 2", "Study 3"),
    measurement_id %in% c("D1", "D2")
  ) %>%
  select(study, sub_id, measurement_id, sr_gamma = param_value) %>%
  group_by(measurement_id) %>%
  summarise(median_gamma = median(sr_gamma)) %>%
  kbl(
    digits = 3,
    caption = str_c(
      "<center>", "Median SR gamma before/after overnight rest", "</center>"
    )
  ) %>%
  kable_styling(bootstrap_options = c("responsive"))

# One-sided paired Wilcoxon signed-rank test: is a subject's D2 gamma greater
# than their D1 gamma? Subjects are paired by (study, sub_id).
params %>%
  filter(
    model == "sr",
    param_name == "sr_gamma",
    study %in% c("Study 2", "Study 3"),
    measurement_id %in% c("D1", "D2")
  ) %>%
  select(study, sub_id, measurement_id, sr_gamma = param_value) %>%
  pivot_wider(names_from = measurement_id, values_from = sr_gamma) %>%
  with(
    wilcox.test(
      D2, D1,
      paired = TRUE, alternative = "greater", conf.int = TRUE
    )
  ) %>%
  tidy() %>%
  kbl(
    digits = 3,
    caption = str_c(
      "<center>", "Increase in SR gamma after overnight rest", "</center>"
    )
  ) %>%
  kable_styling(bootstrap_options = c("responsive"))
```

```{r plot-gamma-increase}
# Fitted SR gamma per subject before (D1) vs after (D2) overnight rest, with a
# crossbar at each day's median.
plot_gamma_change <- params %>%
  filter(
    model == "sr",
    param_name == "sr_gamma",
    study %in% c("Study 2", "Study 3"),
    measurement_id %in% c("D1", "D2")
  ) %>%
  select(study, sub_id, measurement_id, sr_gamma = param_value) %>%
  ggplot(aes(x = measurement_id, y = sr_gamma, color = measurement_id)) +
  theme_custom() +
  geom_point(
    alpha = 0.25,
    position = position_jitterdodge(
      jitter.width = 0.25, jitter.height = 0, dodge.width = 0.5, seed = 1
    ),
    show.legend = FALSE
  ) +
  stat_summary(
    fun = median, geom = "crossbar", position = "dodge", show.legend = FALSE
  ) +
  scale_x_discrete(
    name = "Measurement",
    labels = c("D1"="Before rest", "D2"="After rest")
  ) +
  scale_y_continuous(
    name = str_c("Estimated ", unicode_greek["gamma"]),
    breaks = seq(0, 1, 0.25)
  ) +
  scale_color_manual(
    name = "Measurement",
    values = c("D1"="#fa9fb5", "D2"="#7a0177"),
    labels = c("D1"="Before rest", "D2"="After rest")
  ) +
  coord_cartesian(ylim = c(0, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle(
    str_c(
      unicode_greek["Delta"], unicode_greek["gamma"], " after overnight rest"
    )
  )

plot_gamma_change

if (knitting) {
  ggsave(
    here("figures", str_c("gamma_change", ".pdf")),
    plot = plot_gamma_change,
    width = 8/3, height = 2.5,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
}
```

Are changes in estimated gamma related to changes in navigation behaviors?

```{r stats-gamma-accuracy}
# Does each subject's D1 -> D2 change in fitted SR gamma track their D1 -> D2
# change in accuracy? One-sided Spearman correlation per distance.
bind_rows(nav_study2, nav_study3) %>%
  filter(measurement_id %in% c("D1", "D2")) %>%
  group_by(study, sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  pivot_wider(names_from = measurement_id, values_from = accuracy) %>%
  mutate(delta_accuracy = D2 - D1, .keep = "unused") %>%
  # Attach the per-subject change in fitted SR gamma
  left_join(
    params %>%
      filter(
        model == "sr",
        param_name == "sr_gamma",
        study %in% c("Study 2", "Study 3"),
        measurement_id %in% c("D1", "D2")
      ) %>%
      select(study, sub_id, measurement_id, sr_gamma = param_value) %>%
      pivot_wider(names_from = measurement_id, values_from = sr_gamma) %>%
      mutate(delta_sr = D2 - D1, .keep = "unused"),
    by = c("study", "sub_id")
  ) %>%
  group_by(shortest_path) %>%
  nest() %>%
  mutate(
    test = map(data, function(d) {
      tidy(
        cor.test(
          d$delta_accuracy, d$delta_sr,
          method = "spearman", exact = FALSE, alternative = "greater"
        )
      )
    })
  ) %>%
  unnest(test) %>%
  ungroup() %>%
  select(!data) %>%
  kbl(
    digits = 3,
    caption = str_c(
      "<center>", "∆ Accuracy ~ ∆ SR gamma (after overnight rest)", "</center>"
    )
  ) %>%
  kable_styling(bootstrap_options = c("responsive"))
```

```{r plot-gamma-accuracy}
# Scatter of per-subject D1 -> D2 change in accuracy against change in fitted
# SR gamma, one colour and linear fit per shortest-path distance.
plot_gamma_accuracy_change <- bind_rows(nav_study2, nav_study3) %>%
  filter(measurement_id %in% c("D1", "D2")) %>%
  group_by(study, sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  pivot_wider(names_from = measurement_id, values_from = accuracy) %>%
  mutate(delta_accuracy = D2 - D1, .keep = "unused") %>%
  # Attach the per-subject change in fitted SR gamma
  left_join(
    params %>%
      filter(
        model == "sr",
        param_name == "sr_gamma",
        study %in% c("Study 2", "Study 3"),
        measurement_id %in% c("D1", "D2")
      ) %>%
      select(study, sub_id, measurement_id, sr_gamma = param_value) %>%
      pivot_wider(names_from = measurement_id, values_from = sr_gamma) %>%
      mutate(delta_sr = D2 - D1, .keep = "unused"),
    by = c("study", "sub_id")
  ) %>%
  ggplot(aes(x = delta_sr, y = delta_accuracy, color = shortest_path)) +
  theme_custom() +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_vline(xintercept = 0, linetype = "dashed") +
  geom_point(alpha = 0.25, show.legend = FALSE) +
  geom_smooth(method = "lm", se = FALSE, linewidth = 1.5) +
  scale_x_continuous(name = str_c("Change in ", unicode_greek["gamma"])) +
  scale_y_continuous(name = "Change in accuracy") +
  scale_color_manual(
    name = "Shortest path distance",
    values = c("#88CCEE", "#CC6677", "#DDCC77")
  ) +
  coord_cartesian(xlim = c(-1, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle(
    str_c(
      unicode_greek["Delta"], "Navigation ~ ",
      unicode_greek["Delta"], unicode_greek["gamma"]
    )
  )

plot_gamma_accuracy_change

if (knitting) {
  # Default pdf device (its warnings are suppressed), then a cairo_pdf copy
  suppressWarnings(
    ggsave(
      here("figures", str_c("gamma_accuracy_change", ".pdf")),
      plot = plot_gamma_accuracy_change,
      width = 8/3, height = 2.5,
      units = "in", dpi = 300
    )
  )

  ggsave(
    here("figures", str_c("gamma_accuracy_change_cairo", ".pdf")),
    plot = plot_gamma_accuracy_change,
    width = 8/3, height = 2.5,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
}
```


# Evidence of cached representation

Here, we try to characterize what is being cached, and to test whether caching (as opposed to model-based planning) is the primary driver of the improvement in navigation after overnight rest.

## Simulation visualization

```{r plot-network-with-sr}
butterfly_layout <- create_layout(g, layout = "stress")

# Cognitive maps predicted by the SR at three discount factors
# (gamma = 0.1, 0.5, 0.9), drawn on the fixed `butterfly_layout` positions.
# The SR values for (a, b) and (b, a) are averaged into a single undirected
# edge, and pairs whose mean SR is <= 0.05 are not drawn. Edge colour follows
# the `edge` column joined from `adjlist`. Per the scale labels below,
# "1" = observed connection and "0" = inferred connection.
plot_network_with_sr <- sim_sr_matrix_asymptotic %>%
  # Restrict to the plotted gammas before aggregating. This yields exactly the
  # same groups as filtering after summarise(), with less work.
  filter(round(gamma, 1) %in% c(0.1, 0.5, 0.9)) %>%
  mutate(
    from_sorted = if_else(from < to, from, to),
    to_sorted = if_else(from < to, to, from)
  ) %>%
  group_by(gamma, from = from_sorted, to = to_sorted) %>%
  # Drop the grouping. Otherwise the result stays grouped by (gamma, from),
  # and the mutate() below builds factor(edge) separately within each group
  # and regroups after rewriting the `gamma` grouping column.
  summarise(sr_value = mean(sr_value), .groups = "drop") %>%
  filter(sr_value > 0.05) %>%
  left_join(adjlist) %>%
  # Pairs are already sorted (from <= to), so this only removes self-pairs.
  filter(from < to) %>%
  mutate(
    edge = factor(edge),
    gamma = str_c(unicode_greek["gamma"], " = ", gamma)
  ) %>%
  tbl_graph(edges = ., directed = FALSE) %>%
  # Label each node with its index.
  mutate(name = row_number()) %>%
  ggraph("manual", x = butterfly_layout$x, y = butterfly_layout$y) +
  theme_network() +
  facet_edges(~gamma, ncol = 1) +
  geom_edge_link(aes(alpha = sr_value, color = edge)) +
  geom_node_label(aes(label = name)) +
  scale_edge_color_manual(
    name = NULL,
    values = c("0"="red", "1"="black"),
    labels = c("0"="Inferred connections", "1"="Observed connections")
  ) +
  scale_edge_alpha(name = "p(Target | Source)") +
  theme(
    legend.position = "bottom",
    legend.box = "vertical",
    strip.background = element_blank(),
    strip.text = element_text(size = 13)
  ) +
  ggtitle("Cognitive maps predicted by SR")

plot_network_with_sr

# Write the figure to disk only while the document is being knit.
if (knitting) {
  # Default pdf() device. All of its warnings are muffled; presumably these are
  # glyph-conversion warnings for the Unicode gamma in the facet labels
  # (TODO confirm). A cairo_pdf copy is written right after.
  suppressWarnings(
    ggsave(
      filename = here("figures", str_c("network_with_sr", ".pdf")),
      plot = plot_network_with_sr,
      units = "in", width = 3.5, height = 6, dpi = 300
    )
  )

  # Same figure, same size, rendered through the cairo_pdf device.
  ggsave(
    filename = here("figures", str_c("network_with_sr_cairo", ".pdf")),
    plot = plot_network_with_sr,
    device = cairo_pdf,
    units = "in", width = 3.5, height = 6, dpi = 300
  )
}
```

## Transition reevaluation

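Transition reevaluation changes a relatively small set of transitions in the network. A caching account predicts that navigation should get worse after reevaluation, relative to performance after overnight rest, because the cached routes no longer match the changed transitions. In contrast, a planning account predicts that navigation should not be greatly affected by such a small set of changes, because routes are recomputed from the updated map.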

```{r stats-reevaluation}
# Fit the transition-reevaluation navigation model on studies 2 and 3 pooled,
# using `ref_dist` as the reference level of shortest_path.
# D2b (after reevaluation) is always the reference measurement. Under default
# treatment contrasts, the D1 and D2 terms therefore compare each measurement
# against post-reevaluation accuracy at the reference distance.
#
# ref_dist: a level of shortest_path, given as a string ("2", "3", "4").
fit_nav_tr <- function(ref_dist) {
  bind_rows(nav_study2, nav_study3) %>%
    filter(measurement_id %in% c("D1", "D2", "D2b")) %>%
    mutate(
      measurement_id = fct_relevel(measurement_id, "D2b"),
      shortest_path = fct_relevel(shortest_path, ref_dist)
    ) %>%
    glmmTMB(
      correct ~ shortest_path * measurement_id +
        (1 + shortest_path + measurement_id | sub_id) + (1 | study),
      family = binomial,
      data = .
    )
}

# One fit per reference distance, so that every distance's contrasts can be
# read directly off a coefficient table.
stats_nav_tr_dist2 <- fit_nav_tr("2")
stats_nav_tr_dist3 <- fit_nav_tr("3")
stats_nav_tr_dist4 <- fit_nav_tr("4")

tidy_nav_tr <- map_dfr(
  .x = list(
    "dist-2" = stats_nav_tr_dist2,
    "dist-3" = stats_nav_tr_dist3,
    "dist-4" = stats_nav_tr_dist4
  ),
  .f = ~tidy(.x, conf.int = TRUE),
  .id = "ref_cat"
) %>%
  check_significance()

# Size each packed row group from the table itself rather than from
# hard-coded row ranges. The grouping then stays correct if the number of
# tidy() rows per model changes, for example after editing the
# random-effects structure.
ref_cat_runs <- rle(tidy_nav_tr$ref_cat)

tidy_nav_tr %>%
  select(-c(ref_cat, effect, component)) %>%
  kbl(
    caption = str_c(
      "<center>", "Navigation accuracy: Transition Reevaluation", "</center>"
    ),
    digits = 3
  ) %>%
  kable_styling(bootstrap_options = c("responsive")) %>%
  pack_rows(
    index = setNames(
      ref_cat_runs$lengths,
      str_c("Ref. Cat. ", ref_cat_runs$values)
    )
  )
```

```{r plot-tr}
# Population-level predictions (sub_id and study set to NA) for every
# measurement x distance cell, taken from the dist-2-reference model.
# NOTE(review): the model pools studies 2 and 3, but the dots below show
# study 3 only. Confirm this is intended.
predict_tr <- expand_grid(
  measurement_id = c("D1", "D2", "D2b"),
  shortest_path = factor(2:4),
  sub_id = NA, study = NA
) %>%
  predict_glmmTMB(model_object = stats_nav_tr_dist2)

# Per-participant accuracy (dots) with model fit +/- 1 SE (pointranges joined
# by lines), split by measurement and shortest-path distance.
plot_nav_tr <- nav_study3 %>%
  filter(measurement_id %in% c("D1", "D2", "D2b")) %>%
  group_by(sub_id, measurement_id, shortest_path) %>%
  summarise(accuracy = mean(correct), .groups = "drop") %>%
  ggplot(aes(x = shortest_path, y = accuracy, color = measurement_id)) +
  theme_custom() +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_dotplot(
    aes(fill = measurement_id),
    binaxis = "y", binwidth = 0.01, stackdir = "center", dotsize = 1,
    position = position_dodge(width = 0.75),
    color = NA, alpha = 0.5,
    show.legend = FALSE
  ) +
  geom_pointrange(
    aes(
      x = shortest_path, color = measurement_id,
      y = fit, ymin = fit - se.fit, ymax = fit + se.fit
    ),
    data = predict_tr, inherit.aes = FALSE,
    position = position_dodge(width = 0.25), linewidth = 1,
    show.legend = FALSE
  ) +
  geom_line(
    aes(
      x = shortest_path, y = fit,
      color = measurement_id, group = measurement_id
    ),
    data = predict_tr, inherit.aes = FALSE,
    position = position_dodge(width = 0.25), linewidth = 1
  ) +
  scale_color_manual(
    name = "Measurement",
    values = c("D1"="#fa9fb5", "D2"="#7a0177", "D2b"="#238b45"),
    labels = c(
      "D1"="Before overnight rest",
      "D2"="After overnight rest",
      "D2b"="After reevaluation"
    )
  ) +
  scale_fill_manual(
    values = c("D1"="#fa9fb5", "D2"="#7a0177", "D2b"="#238b45")
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(
    name = "Accuracy", breaks = seq(0, 1, 0.25), labels = scales::percent
  ) +
  coord_cartesian(ylim = c(0.3, 1.1)) +
  theme(legend.position = "bottom") +
  ggtitle("Navigation after transition reevaluation")

plot_nav_tr

# Write the figure to disk only while the document is being knit, using the
# default pdf() device.
if (knitting) {
  ggsave(
    filename = here("figures", str_c("navigation_reevaluation", ".pdf")),
    plot = plot_nav_tr,
    units = "in", width = 4, height = 4, dpi = 300
  )
}
```


# Session info

For reproducibility.

```{r session-info}
# Print R version, platform, and attached/loaded package versions.
utils::sessionInfo()
```

